Compare commits

..

2 Commits

Author SHA1 Message Date
pablodanswer
25b38212e9 nit 2025-01-19 09:50:35 -08:00
pablodanswer
3096b0b2a7 add linear check 2025-01-19 09:49:26 -08:00
159 changed files with 4323 additions and 4649 deletions

View File

@@ -11,4 +11,5 @@
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] I have included a link to a Linear ticket in my description.
- [ ] [Optional] Override Linear Check

View File

@@ -67,7 +67,6 @@ jobs:
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -60,8 +60,6 @@ jobs:
push: true
build-args: |
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -119,7 +119,7 @@ There are two editions of Onyx:
- Whitelabeling
- API key authentication
- Encryption of secrets
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Onyx Enterprise Edition:

View File

@@ -1,33 +0,0 @@
"""add passthrough auth to tool
Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add passthrough_auth column to tool table with default value of False
op.add_column(
"tool",
sa.Column(
"passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
def downgrade() -> None:
# Remove passthrough_auth column from tool table
op.drop_column("tool", "passthrough_auth")

View File

@@ -1,72 +1,30 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule,
)
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
ee_cloud_tasks_to_schedule = [
ee_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
},
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
},
},
]
ee_tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
ee_tasks_to_schedule = [
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
return base_cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:

View File

@@ -4,20 +4,6 @@ import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")
if _OIDC_SCOPE_OVERRIDE:
try:
OIDC_SCOPE_OVERRIDE = [
scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
]
except Exception:
pass
# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"

View File

@@ -98,9 +98,10 @@ def get_page_of_chat_sessions(
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
subquery = (
select(ChatSession.id)
select(ChatSession.id, ChatSession.time_created)
.filter(*conditions)
.order_by(desc(ChatSession.time_created), ChatSession.id)
.order_by(ChatSession.id, desc(ChatSession.time_created))
.distinct(ChatSession.id)
.limit(page_size)
.offset(page_num * page_size)
.subquery()
@@ -117,11 +118,7 @@ def get_page_of_chat_sessions(
ChatMessage.chat_message_feedbacks
),
)
.order_by(
desc(ChatSession.time_created),
ChatSession.id,
asc(ChatMessage.id), # Ensure chronological message order
)
.order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
)
return db_session.scalars(stmt).unique().all()

View File

@@ -1,9 +1,7 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID
from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
@@ -90,13 +88,7 @@ def get_application() -> FastAPI:
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
),
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,

View File

@@ -23,6 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
print("preferences_data", preferences_data)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(

View File

@@ -47,8 +47,3 @@ class UserUpdate(schemas.BaseUserUpdate):
Role updates are not allowed through the user update endpoint for security reasons
Role changes should be handled through a separate, admin-only process
"""
class AuthBackend(str, Enum):
REDIS = "redis"
POSTGRES = "postgres"

View File

@@ -33,8 +33,6 @@ from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
@@ -54,15 +52,13 @@ from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -78,7 +74,6 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
@@ -87,7 +82,6 @@ from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
@@ -215,7 +209,6 @@ def verify_email_domain(email: str) -> None:
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
@@ -587,14 +580,6 @@ def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
"""
A custom strategy that fetches the actual async Redis connection inside each method.
@@ -603,7 +588,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
def __init__(
self,
lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
self.lifetime_seconds = lifetime_seconds
@@ -652,16 +637,9 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
await redis.delete(f"{self.key_prefix}{token}")
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
elif AUTH_BACKEND == AuthBackend.POSTGRES:
auth_backend = AuthenticationBackend(
name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
)
else:
raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
auth_backend = AuthenticationBackend(
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):

View File

@@ -23,7 +23,8 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
@@ -279,6 +280,51 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
client = get_vespa_http_client()
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
@@ -464,13 +510,3 @@ def reset_tenant_id(
) -> None:
"""Signal handler to reset tenant ID in context var after task ends."""
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if not wait_for_vespa_with_timeout():
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)

View File

@@ -81,7 +81,7 @@ class DynamicTenantScheduler(PersistentScheduler):
cloud_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": task.get("kwargs", {}),
"kwargs": {},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")

View File

@@ -62,7 +62,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -68,7 +68,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -63,7 +63,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -86,7 +86,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")

View File

@@ -17,234 +17,124 @@ from shared_configs.configs import MULTI_TENANT
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# the name attribute must start with ONYX_CELERY_CLOUD_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule = [
# cloud specific tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
# remaining tasks are cloud generators for per tenant tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"task": OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.LOW,
},
},
]
if LLM_MODEL_UPDATE_API_URL:
cloud_tasks_to_schedule.append(
# tasks that run in either self-hosted on cloud
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-prune",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "kombu-message-cleanup",
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {
"priority": OnyxCeleryPriority.LOWEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
if not MULTI_TENANT:
tasks_to_schedule.append(
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(hours=1), # Check every hour
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"priority": OnyxCeleryPriority.LOW,
},
}
)
# tasks that run in either self-hosted on cloud
tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
tasks_to_schedule.extend(
[
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
}
)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return cloud_tasks_to_schedule

View File

@@ -15,6 +15,7 @@ from redis import Redis
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
@@ -25,12 +26,15 @@ from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingMode
from onyx.db.index_attempt import get_index_attempt
@@ -64,6 +68,10 @@ logger = setup_logger()
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
debug_tenants = {
"tenant_i-043470d740845ec56",
"tenant_82b497ce-88aa-4fbd-841a-92cae43529c8",
}
time_start = time.monotonic()
tasks_created = 0
@@ -115,6 +123,16 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# kick off index attempts
for cc_pair_id in cc_pair_ids:
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair lock: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
@@ -123,12 +141,30 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
db_session
)
for search_settings_instance in search_settings_list:
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair search settings lock: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
if redis_connector_index.fenced:
continue
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing get_connector_credential_pair_from_id: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -136,10 +172,28 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not cc_pair:
continue
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing get_last_attempt_for_cc_pair: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair should index: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
@@ -172,6 +226,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
cc_pair.id, None, db_session
)
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair try_creating_indexing_task: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
@@ -192,6 +255,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
tasks_created += 1
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing cc_pair try_creating_indexing_task finished: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={ttl}"
)
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing unfenced lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
# Fail any index attempts in the DB that don't have fences
@@ -201,7 +282,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
db_session, redis_client
)
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing after get unfenced lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
for attempt_id in unfenced_attempt_ids:
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing unfenced attempt id lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
@@ -219,6 +317,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
attempt.id, db_session, failure_reason=failure_reason
)
# debugging logic - remove after we're done
if tenant_id in debug_tenants:
ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
task_logger.info(
f"check_for_indexing validate fences lock: "
f"tenant={tenant_id} "
f"ttl={ttl}"
)
lock_beat.reacquire()
# we want to run this less frequently than the overall task
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
@@ -567,9 +674,6 @@ def connector_indexing_proxy_task(
while True:
sleep(5)
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
# renew active signal
redis_connector_index.set_active()
@@ -676,10 +780,67 @@ def connector_indexing_proxy_task(
)
continue
redis_connector_index.set_watchdog(False)
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
trail=False,
bind=True,
)
def cloud_check_for_indexing(self: Task) -> bool | None:
"""a lightweight task used to kick off individual check tasks for each tenant."""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CLOUD_CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
try:
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
self.app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
kwargs=dict(
tenant_id=tenant_id,
),
priority=OnyxCeleryPriority.HIGH,
expires=BEAT_EXPIRES_DEFAULT,
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud indexing check")
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error("cloud_check_for_indexing - Lock not owned on completion")
redis_lock_dump(lock_beat, redis_client)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_check_for_indexing finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -1,8 +1,6 @@
import json
import time
from collections.abc import Callable
from datetime import timedelta
from itertools import islice
from typing import Any
from celery import shared_task
@@ -12,17 +10,13 @@ from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
@@ -32,9 +26,7 @@ from onyx.db.models import DocumentSet
from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -185,10 +177,6 @@ def _build_connector_start_latency_metric(
start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()
task_logger.info(
f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
)
return Metric(
key=metric_key,
name="connector_start_latency",
@@ -222,9 +210,6 @@ def _build_run_success_metrics(
IndexingStatus.FAILED,
IndexingStatus.CANCELED,
]:
task_logger.info(
f"Adding run success metric for index attempt {attempt.id} with status {attempt.status}"
)
metrics.append(
Metric(
key=metric_key,
@@ -245,29 +230,25 @@ def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Me
# Get all connector credential pairs
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
active_search_settings = get_active_search_settings(db_session)
metrics = []
for cc_pair, search_settings in zip(cc_pairs, active_search_settings):
for cc_pair in cc_pairs:
# Get all attempts in the last hour
recent_attempts = (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair.id,
IndexAttempt.search_settings_id == search_settings.id,
IndexAttempt.time_created >= one_hour_ago,
)
.order_by(IndexAttempt.time_created.desc())
.limit(2)
.all()
)
if not recent_attempts:
continue
most_recent_attempt = recent_attempts[0]
most_recent_attempt = recent_attempts[0] if recent_attempts else None
second_most_recent_attempt = (
recent_attempts[1] if len(recent_attempts) > 1 else None
)
if one_hour_ago > most_recent_attempt.time_created:
# if no metric to emit, skip
if most_recent_attempt is None:
continue
# Connector start latency
@@ -310,7 +291,7 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
f"{sync_record.entity_id}:{sync_record.id}"
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
task_logger.debug(
f"Skipping metric for sync record {sync_record.id} "
"because it has already been emitted"
)
@@ -330,15 +311,11 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
if sync_speed is None:
task_logger.error(
f"Something went wrong with sync speed calculation. "
f"Sync record: {sync_record.id}, duration: {sync_duration_mins}, "
f"docs synced: {sync_record.num_docs_synced}"
"Something went wrong with sync speed calculation. "
f"Sync record: {sync_record.id}"
)
continue
task_logger.info(
f"Calculated sync speed for record {sync_record.id}: {sync_speed} docs/min"
)
metrics.append(
Metric(
key=metric_key,
@@ -357,7 +334,7 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
f":{sync_record.entity_id}:{sync_record.id}"
)
if _has_metric_been_emitted(redis_std, start_latency_key):
task_logger.info(
task_logger.debug(
f"Skipping start latency metric for sync record {sync_record.id} "
"because it has already been emitted"
)
@@ -375,7 +352,7 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
)
else:
# Skip other sync types
task_logger.info(
task_logger.debug(
f"Skipping sync record {sync_record.id} "
f"with type {sync_record.sync_type} "
f"and id {sync_record.entity_id} "
@@ -394,15 +371,12 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
start_latency = (
sync_record.sync_start_time - entity.time_last_modified_by_user
).total_seconds()
task_logger.info(
f"Calculated start latency for sync record {sync_record.id}: {start_latency} seconds"
)
if start_latency < 0:
task_logger.error(
f"Start latency is negative for sync record {sync_record.id} "
f"with type {sync_record.sync_type} and id {sync_record.entity_id}. "
f"Sync start time: {sync_record.sync_start_time}, "
f"Entity last modified: {entity.time_last_modified_by_user}"
f"with type {sync_record.sync_type} and id {sync_record.entity_id}."
"This is likely because the entity was updated between the time the "
"time the sync finished and this job ran. Skipping."
)
continue
@@ -482,116 +456,3 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
lock_monitoring.release()
task_logger.info("Background monitoring task finished")
@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
"""A task to verify that all tenants are on the same alembic revision.
This check is expected to fail if a cloud alembic migration is currently running
across all tenants.
TODO: have the cloud migration script set an activity signal that this check
uses to know it doesn't make sense to run a check at the present time.
"""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
tenant_to_revision: dict[str, str | None] = {}
revision_counts: dict[str, int] = {}
out_of_date_tenants: dict[str, str | None] = {}
top_revision: str = ""
try:
# map each tenant_id to its revision
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
if tenant_id is None:
continue
with get_session_with_tenant(tenant_id=None) as session:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
tenant_to_revision[tenant_id] = result_scalar
# get the total count of each revision
for k, v in tenant_to_revision.items():
if v is None:
continue
revision_counts[v] = revision_counts.get(v, 0) + 1
# get the revision with the most counts
sorted_revision_counts = sorted(
revision_counts.items(), key=lambda item: item[1], reverse=True
)
if len(sorted_revision_counts) == 0:
task_logger.error(
f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
)
else:
top_revision, _ = sorted_revision_counts[0]
# build a list of out of date tenants
for k, v in tenant_to_revision.items():
if v == top_revision:
continue
out_of_date_tenants[k] = v
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud alembic check")
raise
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error("cloud_check_alembic - Lock not owned on completion")
redis_lock_dump(lock_beat, redis_client)
if len(out_of_date_tenants) > 0:
task_logger.error(
f"Found out of date tenants: "
f"num_out_of_date_tenants={len(out_of_date_tenants)} "
f"num_tenants={len(tenant_ids)} "
f"revision={top_revision}"
)
for k, v in islice(out_of_date_tenants.items(), 5):
task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
else:
task_logger.info(
f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -1,22 +1,15 @@
import time
from http import HTTPStatus
import httpx
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.document import fetch_chunk_count_for_document
@@ -25,13 +18,10 @@ from onyx.db.document import get_document_connector_count
from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
@@ -209,73 +199,3 @@ def document_by_cc_pair_cleanup_task(
return False
return True
@shared_task(
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
trail=False,
bind=True,
)
def cloud_beat_task_generator(
self: Task,
task_name: str,
queue: str = OnyxCeleryTask.DEFAULT,
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
try:
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
self.app.send_task(
task_name,
kwargs=dict(
tenant_id=tenant_id,
),
queue=queue,
priority=priority,
expires=expires,
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
finally:
if not lock_beat.owned():
task_logger.error(
"cloud_beat_task_generator - Lock not owned on completion"
)
redis_lock_dump(lock_beat, redis_client)
else:
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -735,7 +735,7 @@ def monitor_ccpair_indexing_taskset(
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"Connector indexing: could not parse composite_id from {fence_key}"
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
)
return
@@ -785,7 +785,6 @@ def monitor_ccpair_indexing_taskset(
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# Verify: if the generator isn't complete, the task must not be in READY state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
@@ -831,7 +830,7 @@ def monitor_ccpair_indexing_taskset(
)
except Exception:
task_logger.exception(
"Connector indexing - Transient exception marking index attempt as failed: "
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -841,20 +840,6 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
return
if redis_connector_index.watchdog_signaled():
# if the generator is complete, don't clean up until the watchdog has exited
task_logger.info(
f"Connector indexing - Delaying finalization until watchdog has exited: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
return
status_enum = HTTPStatus(status_int)
task_logger.info(
@@ -873,13 +858,9 @@ def monitor_ccpair_indexing_taskset(
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes various long running tasks.
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
now. Should change that at some point.
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
For many tasks, the count is 0, that means all tasks finished and we should clean up.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
@@ -1064,8 +1045,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
@@ -1116,13 +1095,7 @@ def vespa_metadata_sync_task(
# r = get_redis_client(tenant_id=tenant_id)
# r.delete(redis_syncing_key)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
except Exception as ex:

View File

@@ -18,8 +18,8 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
@@ -31,16 +31,15 @@ def default_build_system_message(
prompt_config: PromptConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
add_additional_info_if_no_tag=prompt_config.datetime_aware,
)
if prompt_config.datetime_aware:
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
if not tag_handled_prompt:
if not system_prompt:
return None
return SystemMessage(content=tag_handled_prompt)
system_msg = SystemMessage(content=system_prompt)
return system_msg
def default_build_user_message(
@@ -65,11 +64,8 @@ def default_build_user_message(
else user_query
)
user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
content=build_content_with_imgs(tag_handled_prompt, files)
if files
else tag_handled_prompt
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
)
return user_msg

View File

@@ -21,9 +21,9 @@ from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import build_task_prompt_reminders
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from onyx.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
@@ -127,11 +127,10 @@ def build_citations_system_message(
system_prompt = prompt_config.system_prompt.strip()
if prompt_config.include_citations:
system_prompt += REQUIRE_CITATION_STATEMENT
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt, prompt_config, add_additional_info_if_no_tag=True
)
if prompt_config.datetime_aware:
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
return SystemMessage(content=tag_handled_prompt)
return SystemMessage(content=system_prompt)
def build_citations_user_message(

View File

@@ -9,8 +9,8 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.direct_qa_prompts import CONTEXT_BLOCK
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.direct_qa_prompts import JSON_PROMPT
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
def _build_strong_llm_quotes_prompt(
@@ -39,11 +39,10 @@ def _build_strong_llm_quotes_prompt(
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
).strip()
tag_handled_prompt = handle_onyx_date_awareness(
full_prompt, prompt, add_additional_info_if_no_tag=True
)
if prompt.datetime_aware:
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
return HumanMessage(content=tag_handled_prompt)
return HumanMessage(content=full_prompt)
def build_quotes_user_message(

View File

@@ -3,7 +3,6 @@ import os
import urllib.parse
from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
@@ -56,12 +55,12 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
AUTH_BACKEND = AuthBackend(os.environ.get("AUTH_BACKEND") or AuthBackend.REDIS.value)
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS")
or os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS")
or 86400 * 7
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# Default request timeout, mostly used by connectors
@@ -93,12 +92,6 @@ OAUTH_CLIENT_SECRET = (
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
# By default, this is set to match the Redis expiry time for consistency.
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
os.environ.get("AUTH_COOKIE_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"

View File

@@ -294,8 +294,7 @@ class OnyxRedisLocks:
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat"
class OnyxRedisSignals:
@@ -318,11 +317,6 @@ ONYX_CLOUD_TENANT_ID = "cloud"
class OnyxCeleryTask:
DEFAULT = "celery"
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
@@ -330,10 +324,8 @@ class OnyxCeleryTask:
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
@@ -351,6 +343,8 @@ class OnyxCeleryTask:
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15

View File

@@ -71,20 +71,10 @@ class AirtableConnector(LoadConnector):
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
return None
@staticmethod
def _extract_field_values(
field_id: str,
field_info: Any,
field_type: str,
base_id: str,
table_id: str,
view_id: str | None,
record_id: str,
) -> list[tuple[str, str]]:
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
"""
Extract value(s) + links from a field regardless of its type.
Attachments are represented as multiple sections, and therefore
returned as a list of tuples (value, link).
Extract value(s) from a field regardless of its type.
Returns either a single string or list of strings for attachments.
"""
if field_info is None:
return []
@@ -95,11 +85,8 @@ class AirtableConnector(LoadConnector):
if field_type == "multipleRecordLinks":
return []
# default link to use for non-attachment fields
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
if field_type == "multipleAttachments":
attachment_texts: list[tuple[str, str]] = []
attachment_texts: list[str] = []
for attachment in field_info:
url = attachment.get("url")
filename = attachment.get("filename", "")
@@ -122,7 +109,6 @@ class AirtableConnector(LoadConnector):
if attachment_content:
try:
file_ext = get_file_ext(filename)
attachment_id = attachment["id"]
attachment_text = extract_file_text(
BytesIO(attachment_content),
filename,
@@ -130,20 +116,7 @@ class AirtableConnector(LoadConnector):
extension=file_ext,
)
if attachment_text:
# slightly nicer loading experience if we can specify the view ID
if view_id:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
else:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
attachment_texts.append(
(f"{filename}:\n{attachment_text}", attachment_link)
)
attachment_texts.append(f"{filename}:\n{attachment_text}")
except Exception as e:
logger.warning(
f"Failed to process attachment {filename}: {str(e)}"
@@ -158,12 +131,12 @@ class AirtableConnector(LoadConnector):
combined.append(collab_name)
if collab_email:
combined.append(f"({collab_email})")
return [(" ".join(combined) if combined else str(field_info), default_link)]
return [" ".join(combined) if combined else str(field_info)]
if isinstance(field_info, list):
return [(item, default_link) for item in field_info]
return [str(item) for item in field_info]
return [(str(field_info), default_link)]
return [str(field_info)]
def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata."""
@@ -171,12 +144,10 @@ class AirtableConnector(LoadConnector):
def _process_field(
self,
field_id: str,
field_name: str,
field_info: Any,
field_type: str,
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
"""
@@ -194,21 +165,12 @@ class AirtableConnector(LoadConnector):
return [], {}
# Get the value(s) for the field
field_value_and_links = self._extract_field_values(
field_id=field_id,
field_info=field_info,
field_type=field_type,
base_id=self.base_id,
table_id=table_id,
view_id=view_id,
record_id=record_id,
)
if len(field_value_and_links) == 0:
field_values = self._get_field_value(field_info, field_type)
if len(field_values) == 0:
return [], {}
# Determine if it should be metadata or a section
if self._should_be_metadata(field_type):
field_values = [value for value, _ in field_value_and_links]
if len(field_values) > 1:
return [], {field_name: field_values}
return [], {field_name: field_values[0]}
@@ -216,7 +178,7 @@ class AirtableConnector(LoadConnector):
# Otherwise, create relevant sections
sections = [
Section(
link=link,
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
text=(
f"{field_name}:\n"
"------------------------\n"
@@ -224,7 +186,7 @@ class AirtableConnector(LoadConnector):
"------------------------"
),
)
for text, link in field_value_and_links
for text in field_values
]
return sections, {}
@@ -257,7 +219,6 @@ class AirtableConnector(LoadConnector):
primary_field_value = (
fields.get(primary_field_name) if primary_field_name else None
)
view_id = table_schema.views[0].id if table_schema.views else None
for field_schema in table_schema.fields:
field_name = field_schema.name
@@ -265,12 +226,10 @@ class AirtableConnector(LoadConnector):
field_type = field_schema.type
field_sections, field_metadata = self._process_field(
field_id=field_schema.id,
field_name=field_name,
field_info=field_val,
field_type=field_type,
table_id=table_id,
view_id=view_id,
record_id=record_id,
)

View File

@@ -432,7 +432,7 @@ def get_paginated_index_attempts_for_cc_pair_id(
stmt = stmt.order_by(IndexAttempt.time_started.desc())
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
return list(db_session.execute(stmt).scalars().all())

View File

@@ -1430,8 +1430,6 @@ class Tool(Base):
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# whether to pass through the user's OAuth token as Authorization header
passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table

View File

@@ -15,7 +15,6 @@ from onyx.db.models import User
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.prompts import get_default_prompt
from onyx.tools.built_in_tools import get_search_tool
from onyx.utils.errors import EERequiredError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@@ -48,10 +47,6 @@ def create_slack_channel_persona(
) -> Persona:
"""NOTE: does not commit changes"""
search_tool = get_search_tool(db_session)
if search_tool is None:
raise ValueError("Search tool not found")
# create/update persona associated with the Slack channel
persona_name = _build_persona_name(channel_name)
default_prompt = get_default_prompt(db_session)
@@ -65,7 +60,6 @@ def create_slack_channel_persona(
llm_filter_extraction=enable_auto_filters,
recency_bias=RecencyBiasSetting.AUTO,
prompt_ids=[default_prompt.id],
tool_ids=[search_tool.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,

View File

@@ -38,7 +38,6 @@ def create_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool,
) -> Tool:
new_tool = Tool(
name=name,
@@ -49,7 +48,6 @@ def create_tool(
if custom_headers
else [],
user_id=user_id,
passthrough_auth=passthrough_auth,
)
db_session.add(new_tool)
db_session.commit()
@@ -64,7 +62,6 @@ def update_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool | None,
) -> Tool:
tool = get_tool_by_id(tool_id, db_session)
if tool is None:
@@ -82,8 +79,6 @@ def update_tool(
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
if passthrough_auth is not None:
tool.passthrough_auth = passthrough_auth
db_session.commit()
return tool

View File

@@ -1,5 +1,4 @@
import re
import time
from typing import cast
import httpx
@@ -8,10 +7,6 @@ from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: This does not seem to be used in reality despite the Vespa Docs pointing to this code
# See here for reference: https://docs.vespa.ai/en/documents.html
@@ -74,37 +69,3 @@ def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=http2,
)
def wait_for_vespa_with_timeout(wait_interval: int = 5, wait_limit: int = 60) -> bool:
"""Waits for Vespa to become ready subject to a timeout.
Returns True if Vespa is ready, False otherwise."""
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
client = get_vespa_http_client()
response = client.get(f"{VESPA_APP_CONTAINER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return True
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > wait_limit:
logger.info(
f"Vespa: Readiness probe did not succeed within the timeout "
f"({wait_limit} seconds)."
)
return False
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit:.1f}"
)
time.sleep(wait_interval)

View File

@@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.store import PgRedisKVStore
def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
def get_kv_store() -> KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore(tenant_id=tenant_id)
return PgRedisKVStore()

View File

@@ -31,27 +31,27 @@ class PgRedisKVStore(KeyValueStore):
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
self.redis_client = get_redis_client(tenant_id=self.tenant_id)
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
@contextmanager
def _get_session(self) -> Iterator[Session]:
def get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(
status_code=401, detail="User must authenticate"
)
if not is_valid_schema_name(self.tenant_id):
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
@@ -66,7 +66,7 @@ class PgRedisKVStore(KeyValueStore):
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self._get_session() as session:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
@@ -88,7 +88,7 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")
with self._get_session() as session:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError
@@ -113,7 +113,7 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with self._get_session() as session:
with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError

View File

@@ -275,22 +275,17 @@ class DefaultMultiLLM(LLM):
# addtional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
# Specifically pass in "vertex_credentials" / "vertex_location" as a
# model_kwarg to the completion call for vertex AI. More details here:
# Specifically pass in "vertex_credentials" as a model_kwarg to the
# completion call for vertex AI. More details here:
# https://docs.litellm.ai/docs/providers/vertex
vertex_credentials_key = "vertex_credentials"
vertex_location_key = "vertex_location"
for k, v in custom_config.items():
if model_provider == "vertex_ai":
if k == vertex_credentials_key:
model_kwargs[k] = v
continue
elif k == vertex_location_key:
model_kwargs[k] = v
continue
# for all values, set them as env variables
os.environ[k] = v
vertex_credentials = custom_config.get(vertex_credentials_key)
if vertex_credentials and model_provider == "vertex_ai":
model_kwargs[vertex_credentials_key] = vertex_credentials
else:
# standard case
for k, v in custom_config.items():
os.environ[k] = v
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})

View File

@@ -212,7 +212,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid(tenant_id=None)
get_or_generate_uuid()
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:

View File

@@ -14,7 +14,6 @@ from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
@@ -123,9 +122,6 @@ class SlackbotHandler:
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
# Store Redis lock objects here so we can release them properly
self.redis_locks: Dict[str | None, Lock] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
@@ -163,15 +159,10 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
# After we finish acquiring and managing Slack bots,
# set the gauge to the number of active tenants (those with Slack bots).
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids)
)
logger.debug(
f"Current active tenants with Slack bots: {len(self.tenant_ids)}"
)
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
@@ -180,9 +171,7 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
)
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
@@ -190,21 +179,17 @@ class SlackbotHandler:
def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot
) -> None:
"""
- If the tokens are missing or empty, close the socket client and remove them.
- If the tokens have changed, close the existing socket client and reconnect.
- If the tokens are new, warm up the model and start a new socket client.
"""
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)
# If the tokens are missing or empty, close the socket client and remove them.
# If the tokens are not set, we need to close the socket client and delete the tokens
# for the tenant and app
if not slack_bot_tokens:
logger.debug(
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
@@ -219,10 +204,9 @@ class SlackbotHandler:
if not tokens_exist or tokens_changed:
if tokens_exist:
logger.info(
f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting"
f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting"
)
else:
# Warm up the model if needed
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
@@ -233,168 +217,77 @@ class SlackbotHandler:
self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens
# Close any existing connection first
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
def acquire_tenants(self) -> None:
"""
- Attempt to acquire a Redis lock for each tenant.
- If acquired, check if that tenant actually has Slack bots.
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
tenant_ids = get_all_tenant_ids()
# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
for tenant_id in tenant_ids:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is disallowed; skipping.")
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
continue
# Already acquired in a previous loop iteration?
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
# Respect max tenant limit per pod
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
)
break
redis_client = get_redis_client(tenant_id=tenant_id)
# Acquire a Redis lock (non-blocking)
rlock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
pod_id = self.pod_id
acquired = redis_client.set(
OnyxRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
lock_acquired = rlock.acquire(blocking=False)
if not lock_acquired and not DEV_MODE:
logger.debug(
f"Another pod holds the lock for tenant {tenant_id}, skipping."
)
if not acquired and not DEV_MODE:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue
if lock_acquired:
logger.debug(f"Acquired lock for tenant {tenant_id}.")
self.redis_locks[tenant_id] = rlock
else:
# DEV_MODE will skip the lock acquisition guard
logger.debug(
f"Running in DEV_MODE. Not enforcing lock for {tenant_id}."
)
logger.debug(f"Acquired lock for tenant {tenant_id}")
# Now check if this tenant actually has Slack bots
self.tenant_ids.add(tenant_id)
for tenant_id in self.tenant_ids:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
bots: list[SlackBot] = []
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass
pass
except Exception as e:
logger.exception(
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
)
if bots:
# Mark as active tenant
self.tenant_ids.add(tenant_id)
bots = fetch_slack_bots(db_session=db_session)
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
else:
# If no Slack bots, release lock immediately (unless in DEV_MODE)
if lock_acquired and not DEV_MODE:
rlock.release()
del self.redis_locks[tenant_id]
logger.debug(
f"No Slack bots for tenant {tenant_id}; lock released (if held)."
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
# 2) Make sure tenants we're handling still have Slack bots
for tenant_id in list(self.tenant_ids):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
redis_client = get_redis_client(tenant_id=tenant_id)
try:
with get_session_with_tenant(tenant_id) as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass (and remove below)
bots = []
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if (tenant_id, bot.id) in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id, bot.id].close())
del self.socket_clients[tenant_id, bot.id]
del self.slack_bot_tokens[tenant_id, bot.id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
bots = []
if not bots:
logger.info(
f"Tenant {tenant_id} no longer has Slack bots. Removing."
)
self._remove_tenant(tenant_id)
# NOTE: We release the lock here (in the same scope it was acquired)
if tenant_id in self.redis_locks and not DEV_MODE:
try:
self.redis_locks[tenant_id].release()
del self.redis_locks[tenant_id]
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Error releasing lock for tenant {tenant_id}: {e}"
)
else:
# Manage or reconnect Slack bot sockets
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _remove_tenant(self, tenant_id: str | None) -> None:
"""
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
"""
# Close all socket clients for this tenant
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
if t_id == tenant_id:
asyncio.run(client.close())
del self.socket_clients[(t_id, slack_bot_id)]
del self.slack_bot_tokens[(t_id, slack_bot_id)]
logger.info(
f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}"
)
# Remove from active set
if tenant_id in self.tenant_ids:
self.tenant_ids.remove(tenant_id)
def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
@@ -422,7 +315,6 @@ class SlackbotHandler:
)
socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client
# Ensure tenant is tracked as active
self.tenant_ids.add(tenant_id)
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -430,7 +322,7 @@ class SlackbotHandler:
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
for (tenant_id, slack_bot_id), client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -448,19 +340,17 @@ class SlackbotHandler:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Release locks for all tenants we currently hold
# Release locks for all tenants
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in list(self.tenant_ids):
if tenant_id in self.redis_locks:
try:
self.redis_locks[tenant_id].release()
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
finally:
del self.redis_locks[tenant_id]
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(OnyxRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
# Wait for background threads to finish (with a timeout)
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)

View File

@@ -19,8 +19,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_DANSWER_DATETIME_REPLACEMENT_PAT = "[[CURRENT_DATETIME]]"
_BASIC_TIME_STR = "The current date is {datetime_info}."
MOST_BASIC_PROMPT = "You are a helpful AI assistant."
DANSWER_DATETIME_REPLACEMENT = "DANSWER_DATETIME_REPLACEMENT"
BASIC_TIME_STR = "The current date is {datetime_info}."
def get_current_llm_day_time(
@@ -37,36 +38,23 @@ def get_current_llm_day_time(
return f"{formatted_datetime}"
def build_date_time_string() -> str:
return ADDITIONAL_INFO.format(
datetime_info=_BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
)
def handle_onyx_date_awareness(
prompt_str: str,
prompt_config: PromptConfig,
add_additional_info_if_no_tag: bool = False,
) -> str:
"""
If there is a [[CURRENT_DATETIME]] tag, replace it with the current date and time no matter what.
If the prompt is datetime aware, and there are no [[CURRENT_DATETIME]] tags, add it to the prompt.
do nothing otherwise.
This can later be expanded to support other tags.
"""
if _DANSWER_DATETIME_REPLACEMENT_PAT in prompt_str:
def add_date_time_to_prompt(prompt_str: str) -> str:
if DANSWER_DATETIME_REPLACEMENT in prompt_str:
return prompt_str.replace(
_DANSWER_DATETIME_REPLACEMENT_PAT,
DANSWER_DATETIME_REPLACEMENT,
get_current_llm_day_time(full_sentence=False, include_day_of_week=True),
)
any_tag_present = any(
_DANSWER_DATETIME_REPLACEMENT_PAT in text
for text in [prompt_str, prompt_config.system_prompt, prompt_config.task_prompt]
)
if add_additional_info_if_no_tag and not any_tag_present:
return prompt_str + build_date_time_string()
return prompt_str
if prompt_str:
return prompt_str + ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
)
else:
return (
MOST_BASIC_PROMPT
+ " "
+ BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
)
def build_task_prompt_reminders(

View File

@@ -30,17 +30,10 @@ class RedisConnectorIndex:
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
TERMINATE_TTL = 600
# used to signal the overall workflow is still active
# there are gaps in time between states where we need some slack
# to correctly transition
# it's difficult to prevent
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
# used to signal that the watchdog is running
WATCHDOG_PREFIX = PREFIX + "_watchdog"
WATCHDOG_TTL = 300
def __init__(
self,
@@ -66,7 +59,6 @@ class RedisConnectorIndex:
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -118,24 +110,7 @@ class RedisConnectorIndex:
"""This sets a signal. It does not block!"""
# We shouldn't need very long to terminate the spawned task.
# 10 minute TTL is good.
self.redis.set(
f"{self.terminate_key}_{celery_task_id}", 0, ex=self.TERMINATE_TTL
)
def set_watchdog(self, value: bool) -> None:
"""Signal the state of the watchdog."""
if not value:
self.redis.delete(self.watchdog_key)
return
self.redis.set(self.watchdog_key, 0, ex=self.WATCHDOG_TTL)
def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
if self.redis.exists(self.watchdog_key):
return True
return False
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -143,7 +118,7 @@ class RedisConnectorIndex:
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
self.redis.set(self.active_key, 0, ex=3600)
def active(self) -> bool:
if self.redis.exists(self.active_key):

View File

@@ -26,7 +26,6 @@ from onyx.db.index_attempt import mock_successful_index_attempt
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -34,6 +33,7 @@ from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import ConnectorBase
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
@@ -218,11 +218,9 @@ def seed_initial_documents(
# Retries here because the index may take a few seconds to become ready
# as we just sent over the Vespa schema and there is a slight delay
if not wait_for_vespa_with_timeout():
logger.error("Vespa did not become ready within the timeout")
raise ValueError("Vespa failed to become ready within the timeout")
document_index.index(
index_with_retries = retry_builder(tries=15)(document_index.index)
index_with_retries(
chunks=chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt={},

View File

@@ -8,7 +8,7 @@ prompts:
# System Prompt (as shown in UI)
system: >
You are a question answering system that is constantly learning and improving.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.
You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
@@ -24,7 +24,7 @@ prompts:
If there are no relevant documents, refer to the chat history and your internal knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# If the [[CURRENT_DATETIME]] is set, the date/time is inserted there instead
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
# Format looks like: "October 16, 2023 14:30"
datetime_aware: true
# Prompts the LLM to include citations in the for [1], [2] etc.
@@ -51,7 +51,7 @@ prompts:
- name: "OnlyLLM"
description: "Chat directly with the LLM!"
system: >
You are a helpful AI assistant. The current date is [[CURRENT_DATETIME]]
You are a helpful AI assistant. The current date is DANSWER_DATETIME_REPLACEMENT
You give concise responses to very simple questions, but provide more thorough responses to
@@ -69,7 +69,7 @@ prompts:
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.
You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
@@ -87,7 +87,7 @@ prompts:
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
Quote and cite relevant information from provided context based on the user query.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.
You only provide quotes that are EXACT substrings from provided documents!

View File

@@ -62,7 +62,7 @@ router = APIRouter(prefix="/manage")
@router.get("/admin/cc-pair/{cc_pair_id}/index-attempts")
def get_cc_pair_index_attempts(
cc_pair_id: int,
page_num: int = Query(0, ge=0),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=1000),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
@@ -81,7 +81,7 @@ def get_cc_pair_index_attempts(
index_attempts = get_paginated_index_attempts_for_cc_pair_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
page=page_num,
page=page,
page_size=page_size,
)
return PaginatedReturn(

View File

@@ -1,8 +1,5 @@
import mimetypes
import os
import uuid
import zipfile
from io import BytesIO
from typing import cast
from fastapi import APIRouter
@@ -389,43 +386,10 @@ def upload_files(
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
# Skip directories and known macOS metadata entries
def should_process_file(file_path: str) -> bool:
normalized_path = os.path.normpath(file_path)
return not any(part.startswith(".") for part in normalized_path.split(os.sep))
try:
file_store = get_default_file_store(db_session)
deduped_file_paths = []
for file in files:
if file.content_type and file.content_type.startswith("application/zip"):
with zipfile.ZipFile(file.file, "r") as zf:
for file_info in zf.namelist():
if zf.getinfo(file_info).is_dir():
continue
if not should_process_file(file_info):
continue
sub_file_bytes = zf.read(file_info)
sub_file_name = os.path.join(str(uuid.uuid4()), file_info)
deduped_file_paths.append(sub_file_name)
mime_type, __ = mimetypes.guess_type(file_info)
if mime_type is None:
mime_type = "application/octet-stream"
file_store.save_file(
file_name=sub_file_name,
content=BytesIO(sub_file_bytes),
display_name=os.path.basename(file_info),
file_origin=FileOrigin.CONNECTOR,
file_type=mime_type,
)
continue
file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename))
deduped_file_paths.append(file_path)
file_store.save_file(

View File

@@ -7,7 +7,6 @@ from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
@@ -192,7 +191,8 @@ def create_persona(
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
datetime_aware=persona_upsert_request.datetime_aware,
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
datetime_aware=False,
include_citations=persona_upsert_request.include_citations,
prompt_id=prompt_id,
)
@@ -236,7 +236,8 @@ def update_persona(
db_session=db_session,
user=user,
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
datetime_aware=persona_upsert_request.datetime_aware,
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
datetime_aware=False,
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
include_citations=persona_upsert_request.include_citations,
@@ -276,14 +277,8 @@ def create_label(
_: User | None = Depends(current_user),
) -> PersonaLabelResponse:
"""Create a new assistant label"""
try:
label_model = create_assistant_label(name=label.name, db_session=db)
return PersonaLabelResponse.from_model(label_model)
except IntegrityError:
raise HTTPException(
status_code=400,
detail=f"Label with name '{label.name}' already exists. Please choose a different name.",
)
label_model = create_assistant_label(name=label.name, db_session=db)
return PersonaLabelResponse.from_model(label_model)
@admin_router.patch("/label/{label_id}")

View File

@@ -60,7 +60,6 @@ class PersonaUpsertRequest(BaseModel):
description: str
system_prompt: str
task_prompt: str
datetime_aware: bool
document_set_ids: list[int]
num_chunks: float
include_citations: bool

View File

@@ -41,16 +41,6 @@ def _validate_tool_definition(definition: dict[str, Any]) -> None:
raise HTTPException(status_code=400, detail=str(e))
def _validate_auth_settings(tool_data: CustomToolCreate | CustomToolUpdate) -> None:
if tool_data.passthrough_auth and tool_data.custom_headers:
for header in tool_data.custom_headers:
if header.key.lower() == "authorization":
raise HTTPException(
status_code=400,
detail="Cannot use passthrough auth with custom authorization headers",
)
@admin_router.post("/custom")
def create_custom_tool(
tool_data: CustomToolCreate,
@@ -58,7 +48,6 @@ def create_custom_tool(
user: User | None = Depends(current_admin_user),
) -> ToolSnapshot:
_validate_tool_definition(tool_data.definition)
_validate_auth_settings(tool_data)
tool = create_tool(
name=tool_data.name,
description=tool_data.description,
@@ -66,7 +55,6 @@ def create_custom_tool(
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
)
return ToolSnapshot.from_model(tool)
@@ -80,7 +68,6 @@ def update_custom_tool(
) -> ToolSnapshot:
if tool_data.definition:
_validate_tool_definition(tool_data.definition)
_validate_auth_settings(tool_data)
updated_tool = update_tool(
tool_id=tool_id,
name=tool_data.name,
@@ -89,7 +76,6 @@ def update_custom_tool(
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
)
return ToolSnapshot.from_model(updated_tool)

View File

@@ -13,7 +13,6 @@ class ToolSnapshot(BaseModel):
display_name: str
in_code_tool_id: str | None
custom_headers: list[Any] | None
passthrough_auth: bool
@classmethod
def from_model(cls, tool: Tool) -> "ToolSnapshot":
@@ -25,7 +24,6 @@ class ToolSnapshot(BaseModel):
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
custom_headers=tool.custom_headers,
passthrough_auth=tool.passthrough_auth,
)
@@ -39,7 +37,6 @@ class CustomToolCreate(BaseModel):
description: str | None = None
definition: dict[str, Any]
custom_headers: list[Header] | None = None
passthrough_auth: bool
class CustomToolUpdate(BaseModel):
@@ -47,4 +44,3 @@ class CustomToolUpdate(BaseModel):
description: str | None = None
definition: dict[str, Any] | None = None
custom_headers: list[Header] | None = None
passthrough_auth: bool | None = None

View File

@@ -714,6 +714,7 @@ def update_user_pinned_assistants(
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.pinned_assistants = ordered_assistant_ids
print("ordered_assistant_ids", ordered_assistant_ids)
set_no_auth_user_preferences(store, no_auth_user.preferences)
return
else:

View File

@@ -5,6 +5,7 @@ import os
import uuid
from collections.abc import Callable
from collections.abc import Generator
from typing import Tuple
from uuid import UUID
from fastapi import APIRouter
@@ -14,6 +15,7 @@ from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -593,6 +595,21 @@ def seed_chat_from_slack(
"""File upload"""
def convert_to_jpeg(file: UploadFile) -> Tuple[io.BytesIO, str]:
try:
with Image.open(file.file) as img:
if img.mode != "RGB":
img = img.convert("RGB")
jpeg_io = io.BytesIO()
img.save(jpeg_io, format="JPEG", quality=85)
jpeg_io.seek(0)
return jpeg_io, "image/jpeg"
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Failed to convert image: {str(e)}"
)
@router.post("/file")
def upload_files_for_chat(
files: list[UploadFile],
@@ -628,9 +645,6 @@ def upload_files_for_chat(
)
for file in files:
if not file.content_type:
raise HTTPException(status_code=400, detail="File content type is required")
if file.content_type not in allowed_content_types:
if file.content_type in image_content_types:
error_detail = "Unsupported image file type. Supported image types include .jpg, .jpeg, .png, .webp."
@@ -662,27 +676,22 @@ def upload_files_for_chat(
file_info: list[tuple[str, str | None, ChatFileType]] = []
for file in files:
file_type = (
ChatFileType.IMAGE
if file.content_type in image_content_types
else ChatFileType.CSV
if file.content_type in csv_content_types
else ChatFileType.DOC
if file.content_type in document_content_types
else ChatFileType.PLAIN_TEXT
)
if file_type == ChatFileType.IMAGE:
file_content = file.file
# NOTE: Image conversion to JPEG used to be enforced here.
# This was removed to:
# 1. Preserve original file content for downloads
# 2. Maintain transparency in formats like PNG
# 3. Ameliorate issue with file conversion
else:
if file.content_type in image_content_types:
file_type = ChatFileType.IMAGE
# Convert image to JPEG
file_content, new_content_type = convert_to_jpeg(file)
elif file.content_type in csv_content_types:
file_type = ChatFileType.CSV
file_content = io.BytesIO(file.file.read())
new_content_type = file.content_type
new_content_type = file.content_type or ""
elif file.content_type in document_content_types:
file_type = ChatFileType.DOC
file_content = io.BytesIO(file.file.read())
new_content_type = file.content_type or ""
else:
file_type = ChatFileType.PLAIN_TEXT
file_content = io.BytesIO(file.file.read())
new_content_type = file.content_type or ""
# store the file (now JPEG for images)
file_id = str(uuid.uuid4())

View File

@@ -104,10 +104,13 @@ def load_builtin_tools(db_session: Session) -> None:
logger.notice("All built-in tools are loaded/verified.")
def get_search_tool(db_session: Session) -> ToolDBModel | None:
def auto_add_search_tool_to_personas(db_session: Session) -> None:
"""
Retrieves for the SearchTool from the BUILT_IN_TOOLS list.
Automatically adds the SearchTool to all Persona objects in the database that have
`num_chunks` either unset or set to a value that isn't 0. This is done to migrate
Persona objects that were created before the concept of Tools were added.
"""
# Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
search_tool_id = next(
(
tool["in_code_tool_id"]
@@ -116,7 +119,6 @@ def get_search_tool(db_session: Session) -> ToolDBModel | None:
),
None,
)
if not search_tool_id:
raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.")
@@ -124,18 +126,6 @@ def get_search_tool(db_session: Session) -> ToolDBModel | None:
select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id)
).scalar_one_or_none()
return search_tool
def auto_add_search_tool_to_personas(db_session: Session) -> None:
"""
Automatically adds the SearchTool to all Persona objects in the database that have
`num_chunks` either unset or set to a value that isn't 0. This is done to migrate
Persona objects that were created before the concept of Tools were added.
"""
# Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
search_tool = get_search_tool(db_session)
if not search_tool:
raise RuntimeError("SearchTool not found in the database.")

View File

@@ -146,11 +146,6 @@ def construct_tools(
"""Constructs tools based on persona configuration and available APIs"""
tool_dict: dict[int, list[Tool]] = {}
# Get user's OAuth token if available
user_oauth_token = None
if user and user.oauth_accounts:
user_oauth_token = user.oauth_accounts[0].access_token
for db_tool_model in persona.tools:
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(
@@ -241,9 +236,6 @@ def construct_tools(
custom_tool_config.additional_headers or {}
)
),
user_oauth_token=(
user_oauth_token if db_tool_model.passthrough_auth else None
),
),
)

View File

@@ -80,12 +80,10 @@ class CustomTool(BaseTool):
method_spec: MethodSpec,
base_url: str,
custom_headers: list[HeaderItemDict] | None = None,
user_oauth_token: str | None = None,
) -> None:
self._base_url = base_url
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()
self._user_oauth_token = user_oauth_token
self._name = self._method_spec.name
self._description = self._method_spec.summary
@@ -93,20 +91,6 @@ class CustomTool(BaseTool):
header_list_to_header_dict(custom_headers) if custom_headers else {}
)
# Check for both Authorization header and OAuth token
has_auth_header = any(
key.lower() == "authorization" for key in self.headers.keys()
)
if has_auth_header and self._user_oauth_token:
logger.warning(
f"Tool '{self._name}' has both an Authorization "
"header and OAuth token set. This is likely a configuration "
"error as the OAuth token will override the custom header."
)
if self._user_oauth_token:
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
@property
def name(self) -> str:
return self._name
@@ -364,7 +348,6 @@ def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
custom_headers: list[HeaderItemDict] | None = None,
dynamic_schema_info: DynamicSchemaInfo | None = None,
user_oauth_token: str | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
# Process dynamic schema information
@@ -383,13 +366,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(
method_spec,
url,
custom_headers,
user_oauth_token=user_oauth_token,
)
for method_spec in method_specs
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
]

View File

@@ -35,16 +35,12 @@ class LongTermLogger:
def _cleanup_old_files(self, category_path: Path) -> None:
try:
files = sorted(
[f for f in category_path.glob("*.json")],
[f for f in category_path.glob("*.json") if f.is_file()],
key=lambda x: x.stat().st_mtime, # Sort by modification time
reverse=True,
)
# Delete oldest files that exceed the limit
for file in files[self.max_files_per_category :]:
if not file.is_file():
logger.debug(f"File already deleted: {file}")
continue
try:
file.unlink()
except Exception as e:

View File

@@ -11,7 +11,7 @@ from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
@@ -41,7 +41,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id))
def get_or_generate_uuid(tenant_id: str | None) -> str:
def get_or_generate_uuid(tenant_id: str | None = None) -> str:
# TODO: split out the whole "instance UUID" generation logic into a separate
# utility function. Telemetry should not be aware at all of how the UUID is
# generated/stored.
@@ -52,7 +52,7 @@ def get_or_generate_uuid(tenant_id: str | None) -> str:
if _CACHED_UUID is not None:
return _CACHED_UUID
kv_store = get_kv_store(tenant_id=tenant_id)
kv_store = get_kv_store()
try:
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
@@ -63,18 +63,18 @@ def get_or_generate_uuid(tenant_id: str | None) -> str:
return _CACHED_UUID
def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: #
def _get_or_generate_instance_domain() -> str | None: #
global _CACHED_INSTANCE_DOMAIN
if _CACHED_INSTANCE_DOMAIN is not None:
return _CACHED_INSTANCE_DOMAIN
kv_store = get_kv_store(tenant_id=tenant_id)
kv_store = get_kv_store()
try:
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except KvKeyNotFoundError:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
first_user = db_session.query(User).first()
if first_user:
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
@@ -94,16 +94,16 @@ def optional_telemetry(
if DISABLE_TELEMETRY:
return
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
try:
def telemetry_logic() -> None:
try:
customer_uuid = (
_get_or_generate_customer_id_mt(tenant_id)
_get_or_generate_customer_id_mt(
tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if MULTI_TENANT
else get_or_generate_uuid(tenant_id)
else get_or_generate_uuid()
)
payload = {
"data": data,
@@ -115,9 +115,7 @@ def optional_telemetry(
"is_cloud": MULTI_TENANT,
}
if ENTERPRISE_EDITION_ENABLED:
payload["instance_domain"] = _get_or_generate_instance_domain(
tenant_id
)
payload["instance_domain"] = _get_or_generate_instance_domain()
requests.post(
_DANSWER_TELEMETRY_ENDPOINT,
headers={"Content-Type": "application/json"},

View File

@@ -72,19 +72,6 @@ def run_jobs() -> None:
"--queues=connector_indexing",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"--queues=monitoring",
]
cmd_beat = [
"celery",
"-A",
@@ -110,13 +97,6 @@ def run_jobs() -> None:
cmd_worker_indexing, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_monitoring_process = subprocess.Popen(
cmd_worker_monitoring,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
@@ -134,23 +114,18 @@ def run_jobs() -> None:
worker_indexing_thread = threading.Thread(
target=monitor_process, args=("INDEX", worker_indexing_process)
)
worker_monitoring_thread = threading.Thread(
target=monitor_process, args=("MONITORING", worker_monitoring_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
worker_primary_thread.start()
worker_light_thread.start()
worker_heavy_thread.start()
worker_indexing_thread.start()
worker_monitoring_thread.start()
beat_thread.start()
worker_primary_thread.join()
worker_light_thread.join()
worker_heavy_thread.join()
worker_indexing_thread.join()
worker_monitoring_thread.join()
beat_thread.join()

View File

@@ -1,30 +1,20 @@
# Tool to run helpful operations on Redis in production
# This is targeted for internal usage and may not have all the necessary parameters
# for general usage across custom deployments
import argparse
import json
import logging
import sys
import time
from logging import getLogger
from typing import cast
from uuid import UUID
from redis import Redis
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REDIS_DB_NUMBER
from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PASSWORD
from onyx.configs.app_configs import REDIS_PORT
from onyx.configs.app_configs import REDIS_SSL
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import RedisPool
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Tool to run helpful operations on Redis in production
# This is targeted for internal usage and may not have all the necessary parameters
# for general usage across custom deployments
# Configure the logger
logging.basicConfig(
@@ -39,18 +29,6 @@ SCAN_ITER_COUNT = 10000
BATCH_DEFAULT = 1000
def get_user_id(user_email: str) -> tuple[UUID, str]:
tenant_id = (
get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
)
with get_session_with_tenant(tenant_id) as session:
user = get_user_by_email(user_email, session)
if user is None:
raise ValueError(f"User not found for email: {user_email}")
return user.id, tenant_id
def onyx_redis(
command: str,
batch: int,
@@ -59,14 +37,13 @@ def onyx_redis(
port: int,
db: int,
password: str | None,
user_email: str | None = None,
) -> int:
pool = RedisPool.create_pool(
host=host,
port=port,
db=db,
password=password if password else "",
ssl=REDIS_SSL,
ssl=True,
ssl_cert_reqs="optional",
ssl_ca_certs=None,
)
@@ -95,25 +72,6 @@ def onyx_redis(
return purge_by_match_and_type(
"*connectorsync:vespa_syncing*", "string", batch, dry_run, r
)
elif command == "get_user_token":
if not user_email:
logger.error("You must specify --user-email with get_user_token")
return 1
token_key = get_user_token_from_redis(r, user_email)
if token_key:
print(f"Token key for user {user_email}: {token_key}")
return 0
else:
print(f"No token found for user {user_email}")
return 2
elif command == "delete_user_token":
if not user_email:
logger.error("You must specify --user-email with delete_user_token")
return 1
if delete_user_token_from_redis(r, user_email, dry_run):
return 0
else:
return 2
else:
pass
@@ -176,104 +134,6 @@ def purge_by_match_and_type(
return 0
def get_user_token_from_redis(r: Redis, user_email: str) -> str | None:
"""
Scans Redis keys for a user token that matches user_email or user_id fields.
Returns the token key if found, else None.
"""
user_id, tenant_id = get_user_id(user_email)
# Scan for keys matching the auth key prefix
auth_keys = r.scan_iter(f"{REDIS_AUTH_KEY_PREFIX}*", count=SCAN_ITER_COUNT)
matching_key = None
for key in auth_keys:
key_str = key.decode("utf-8")
jwt_token = r.get(key_str)
if not jwt_token:
continue
try:
jwt_token_str = (
jwt_token.decode("utf-8")
if isinstance(jwt_token, bytes)
else str(jwt_token)
)
if jwt_token_str.startswith("b'") and jwt_token_str.endswith("'"):
jwt_token_str = jwt_token_str[2:-1] # Remove b'' wrapper
jwt_data = json.loads(jwt_token_str)
if jwt_data.get("tenant_id") == tenant_id and str(
jwt_data.get("sub")
) == str(user_id):
matching_key = key_str
break
except json.JSONDecodeError:
logger.error(f"Failed to decode JSON for key: {key_str}")
except Exception as e:
logger.error(f"Error processing JWT for key: {key_str}. Error: {str(e)}")
if matching_key:
return matching_key[len(REDIS_AUTH_KEY_PREFIX) :]
return None
def delete_user_token_from_redis(
r: Redis, user_email: str, dry_run: bool = False
) -> bool:
"""
Scans Redis keys for a user token matching user_email and deletes it if found.
Returns True if something was deleted, otherwise False.
"""
user_id, tenant_id = get_user_id(user_email)
# Scan for keys matching the auth key prefix
auth_keys = r.scan_iter(f"{REDIS_AUTH_KEY_PREFIX}*", count=SCAN_ITER_COUNT)
matching_key = None
for key in auth_keys:
key_str = key.decode("utf-8")
jwt_token = r.get(key_str)
if not jwt_token:
continue
try:
jwt_token_str = (
jwt_token.decode("utf-8")
if isinstance(jwt_token, bytes)
else str(jwt_token)
)
if jwt_token_str.startswith("b'") and jwt_token_str.endswith("'"):
jwt_token_str = jwt_token_str[2:-1] # Remove b'' wrapper
jwt_data = json.loads(jwt_token_str)
if jwt_data.get("tenant_id") == tenant_id and str(
jwt_data.get("sub")
) == str(user_id):
matching_key = key_str
break
except json.JSONDecodeError:
logger.error(f"Failed to decode JSON for key: {key_str}")
except Exception as e:
logger.error(f"Error processing JWT for key: {key_str}. Error: {str(e)}")
if matching_key:
if dry_run:
logger.info(f"(DRY-RUN) Would delete token key: {matching_key}")
else:
r.delete(matching_key)
logger.info(f"Deleted token for user: {user_email}")
return True
else:
logger.info(f"No token found for user: {user_email}")
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Onyx Redis Manager")
parser.add_argument("--command", type=str, help="Operation to run", required=True)
@@ -325,13 +185,6 @@ if __name__ == "__main__":
required=False,
)
parser.add_argument(
"--user-email",
type=str,
help="User email for get or delete user token",
required=False,
)
args = parser.parse_args()
exitcode = onyx_redis(
command=args.command,
@@ -341,6 +194,5 @@ if __name__ == "__main__":
port=args.port,
db=args.db,
password=args.password,
user_email=args.user_email,
)
sys.exit(exitcode)

View File

@@ -1,269 +0,0 @@
"""
Vespa Debugging Tool!
Usage:
python vespa_debug_tool.py --action <action> [options]
Actions:
config : Print Vespa configuration
connect : Check Vespa connectivity
list_docs : List documents
search : Search documents
update : Update a document
delete : Delete a document
get_acls : Get document ACLs
Options:
--tenant-id : Tenant ID
--connector-id : Connector ID
--n : Number of documents (default 10)
--query : Search query
--doc-id : Document ID
--fields : Fields to update (JSON)
Example: (gets docs for a given tenant id and connector id)
python vespa_debug_tool.py --action list_docs --tenant-id my_tenant --connector-id 1 --n 5
"""
import argparse
import json
from typing import Any
from typing import Dict
from typing import List
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
# Print Vespa configuration URLs
def print_vespa_config() -> None:
print(f"Vespa Application Endpoint: {VESPA_APPLICATION_ENDPOINT}")
print(f"Vespa App Container URL: {VESPA_APP_CONTAINER_URL}")
print(f"Vespa Search Endpoint: {SEARCH_ENDPOINT}")
print(f"Vespa Document ID Endpoint: {DOCUMENT_ID_ENDPOINT}")
# Check connectivity to Vespa endpoints
def check_vespa_connectivity() -> None:
endpoints = [
f"{VESPA_APPLICATION_ENDPOINT}/ApplicationStatus",
f"{VESPA_APPLICATION_ENDPOINT}/tenant",
f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/",
f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/default",
]
for endpoint in endpoints:
try:
with get_vespa_http_client() as client:
response = client.get(endpoint)
print(f"Successfully connected to Vespa at {endpoint}")
print(f"Status code: {response.status_code}")
print(f"Response: {response.text[:200]}...")
except Exception as e:
print(f"Failed to connect to Vespa at {endpoint}: {str(e)}")
print("Vespa connectivity check completed.")
# Get info about the default Vespa application
def get_vespa_info() -> Dict[str, Any]:
url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/default"
with get_vespa_http_client() as client:
response = client.get(url)
response.raise_for_status()
return response.json()
# Get index name for a tenant and connector pair
def get_index_name(tenant_id: str, connector_id: int) -> str:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(db_session, connector_id)
if not cc_pair:
raise ValueError(f"No connector found for id {connector_id}")
search_settings = get_current_search_settings(db_session)
return search_settings.index_name if search_settings else "public"
# Perform a Vespa query using YQL syntax
def query_vespa(yql: str) -> List[Dict[str, Any]]:
params = {
"yql": yql,
"timeout": "10s",
}
with get_vespa_http_client() as client:
response = client.get(SEARCH_ENDPOINT, params=params)
response.raise_for_status()
return response.json()["root"]["children"]
# Get first N documents
def get_first_n_documents(n: int = 10) -> List[Dict[str, Any]]:
yql = f"select * from sources * where true limit {n};"
return query_vespa(yql)
# Pretty-print a list of documents
def print_documents(documents: List[Dict[str, Any]]) -> None:
for doc in documents:
print(json.dumps(doc, indent=2))
print("-" * 80)
# Get and print documents for a specific tenant and connector
def get_documents_for_tenant_connector(
tenant_id: str, connector_id: int, n: int = 10
) -> None:
get_index_name(tenant_id, connector_id)
documents = get_first_n_documents(n)
print(f"First {n} documents for tenant {tenant_id}, connector {connector_id}:")
print_documents(documents)
# Search documents for a specific tenant and connector
def search_documents(
tenant_id: str, connector_id: int, query: str, n: int = 10
) -> None:
index_name = get_index_name(tenant_id, connector_id)
yql = f"select * from sources {index_name} where userInput(@query) limit {n};"
documents = query_vespa(yql)
print(f"Search results for query '{query}':")
print_documents(documents)
# Update a specific document
def update_document(
tenant_id: str, connector_id: int, doc_id: str, fields: Dict[str, Any]
) -> None:
index_name = get_index_name(tenant_id, connector_id)
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}"
update_request = {"fields": {k: {"assign": v} for k, v in fields.items()}}
with get_vespa_http_client() as client:
response = client.put(url, json=update_request)
response.raise_for_status()
print(f"Document {doc_id} updated successfully")
# Delete a specific document
def delete_document(tenant_id: str, connector_id: int, doc_id: str) -> None:
index_name = get_index_name(tenant_id, connector_id)
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}"
with get_vespa_http_client() as client:
response = client.delete(url)
response.raise_for_status()
print(f"Document {doc_id} deleted successfully")
# List documents from any source
def list_documents(n: int = 10) -> None:
yql = f"select * from sources * where true limit {n};"
url = f"{VESPA_APP_CONTAINER_URL}/search/"
params = {
"yql": yql,
"timeout": "10s",
}
try:
with get_vespa_http_client() as client:
response = client.get(url, params=params)
response.raise_for_status()
documents = response.json()["root"]["children"]
print(f"First {n} documents:")
print_documents(documents)
except Exception as e:
print(f"Failed to list documents: {str(e)}")
# Get and print ACLs for documents of a specific tenant and connector
def get_document_acls(tenant_id: str, connector_id: int, n: int = 10) -> None:
index_name = get_index_name(tenant_id, connector_id)
yql = f"select documentid, access_control_list from sources {index_name} where true limit {n};"
documents = query_vespa(yql)
print(f"ACLs for {n} documents from tenant {tenant_id}, connector {connector_id}:")
for doc in documents:
print(f"Document ID: {doc['fields']['documentid']}")
print(
f"ACL: {json.dumps(doc['fields'].get('access_control_list', {}), indent=2)}"
)
print("-" * 80)
def main() -> None:
parser = argparse.ArgumentParser(description="Vespa debugging tool")
parser.add_argument(
"--action",
choices=[
"config",
"connect",
"list_docs",
"search",
"update",
"delete",
"get_acls",
],
required=True,
help="Action to perform",
)
parser.add_argument(
"--tenant-id", help="Tenant ID (for update, delete, and get_acls actions)"
)
parser.add_argument(
"--connector-id",
type=int,
help="Connector ID (for update, delete, and get_acls actions)",
)
parser.add_argument(
"--n",
type=int,
default=10,
help="Number of documents to retrieve (for list_docs, search, update, and get_acls actions)",
)
parser.add_argument("--query", help="Search query (for search action)")
parser.add_argument("--doc-id", help="Document ID (for update and delete actions)")
parser.add_argument(
"--fields", help="Fields to update, in JSON format (for update action)"
)
args = parser.parse_args()
if args.action == "config":
print_vespa_config()
elif args.action == "connect":
check_vespa_connectivity()
elif args.action == "list_docs":
# If tenant_id and connector_id are provided, list docs for that tenant/connector.
# Otherwise, list documents from any source.
if args.tenant_id and args.connector_id:
get_documents_for_tenant_connector(
args.tenant_id, args.connector_id, args.n
)
else:
list_documents(args.n)
elif args.action == "search":
if not args.query:
parser.error("--query is required for search action")
search_documents(args.tenant_id, args.connector_id, args.query, args.n)
elif args.action == "update":
if not args.doc_id or not args.fields:
parser.error("--doc-id and --fields are required for update action")
fields = json.loads(args.fields)
update_document(args.tenant_id, args.connector_id, args.doc_id, fields)
elif args.action == "delete":
if not args.doc_id:
parser.error("--doc-id is required for delete action")
delete_document(args.tenant_id, args.connector_id, args.doc_id)
elif args.action == "get_acls":
if not args.tenant_id or args.connector_id is None:
parser.error(
"--tenant-id and --connector-id are required for get_acls action"
)
get_document_acls(args.tenant_id, args.connector_id, args.n)
if __name__ == "__main__":
main()

View File

@@ -45,7 +45,7 @@ def create_test_document(
submitted_by: str,
assignee: str,
days_since_status_change: int | None,
attachments: list[tuple[str, str]] | None = None,
attachments: list | None = None,
) -> Document:
link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}"
sections = [
@@ -60,11 +60,11 @@ def create_test_document(
]
if attachments:
for attachment_text, attachment_link in attachments:
for attachment in attachments:
sections.append(
Section(
text=f"Attachment:\n------------------------\n{attachment_text}\n------------------------",
link=attachment_link,
text=f"Attachment:\n------------------------\n{attachment}\n------------------------",
link=f"{link_base}/{id}",
),
)
@@ -142,13 +142,7 @@ def test_airtable_connector_basic(
days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
attachments=[
(
"Test.pdf:\ntesting!!!",
# hard code link for now
"https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
)
],
attachments=["Test.pdf:\ntesting!!!"],
),
]

View File

@@ -38,12 +38,6 @@ def get_credentials() -> dict[str, str]:
}
@pytest.mark.xfail(
reason=(
"Cannot get Zendesk developer account to ensure zendesk account does not "
"expire after 2 weeks"
)
)
@pytest.mark.parametrize(
"connector_fixture", ["zendesk_article_connector", "zendesk_ticket_connector"]
)
@@ -102,12 +96,6 @@ def test_zendesk_connector_basic(
)
@pytest.mark.xfail(
reason=(
"Cannot get Zendesk developer account to ensure zendesk account does not "
"expire after 2 weeks"
)
)
def test_zendesk_connector_slim(zendesk_article_connector: ZendeskConnector) -> None:
# Get full doc IDs
all_full_doc_ids = set()

View File

@@ -1,106 +0,0 @@
from datetime import datetime
from datetime import timedelta
from urllib.parse import urlencode
import requests
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexModelStatus
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.search_settings import get_current_search_settings
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
class IndexAttemptManager:
@staticmethod
def create_test_index_attempts(
num_attempts: int,
cc_pair_id: int,
from_beginning: bool = False,
status: IndexingStatus = IndexingStatus.SUCCESS,
new_docs_indexed: int = 10,
total_docs_indexed: int = 10,
docs_removed_from_index: int = 0,
error_msg: str | None = None,
base_time: datetime | None = None,
) -> list[DATestIndexAttempt]:
if base_time is None:
base_time = datetime.now()
attempts = []
with get_session_context_manager() as db_session:
# Get the current search settings
search_settings = get_current_search_settings(db_session)
if (
not search_settings
or search_settings.status != IndexModelStatus.PRESENT
):
raise ValueError("No current search settings found with PRESENT status")
for i in range(num_attempts):
time_created = base_time - timedelta(hours=i)
index_attempt = IndexAttempt(
connector_credential_pair_id=cc_pair_id,
from_beginning=from_beginning,
status=status,
new_docs_indexed=new_docs_indexed,
total_docs_indexed=total_docs_indexed,
docs_removed_from_index=docs_removed_from_index,
error_msg=error_msg,
time_created=time_created,
time_started=time_created,
time_updated=time_created,
search_settings_id=search_settings.id,
)
db_session.add(index_attempt)
db_session.flush() # To get the ID
attempts.append(
DATestIndexAttempt(
id=index_attempt.id,
status=index_attempt.status,
new_docs_indexed=index_attempt.new_docs_indexed,
total_docs_indexed=index_attempt.total_docs_indexed,
docs_removed_from_index=index_attempt.docs_removed_from_index,
error_msg=index_attempt.error_msg,
time_started=index_attempt.time_started,
time_updated=index_attempt.time_updated,
)
)
db_session.commit()
return attempts
@staticmethod
def get_index_attempt_page(
cc_pair_id: int,
page: int = 0,
page_size: int = 10,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[IndexAttemptSnapshot]:
query_params: dict[str, str | int] = {
"page_num": page,
"page_size": page_size,
}
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
data = response.json()
return PaginatedReturn(
items=[IndexAttemptSnapshot(**item) for item in data["items"]],
total_items=data["total_items"],
)

View File

@@ -26,7 +26,6 @@ class PersonaManager:
is_public: bool = True,
llm_filter_extraction: bool = True,
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
datetime_aware: bool = False,
prompt_ids: list[int] | None = None,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
@@ -47,7 +46,6 @@ class PersonaManager:
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
include_citations=include_citations,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
@@ -106,7 +104,6 @@ class PersonaManager:
is_public: bool | None = None,
llm_filter_extraction: bool | None = None,
recency_bias: RecencyBiasSetting | None = None,
datetime_aware: bool = False,
prompt_ids: list[int] | None = None,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
@@ -124,7 +121,6 @@ class PersonaManager:
description=description or persona.description,
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
include_citations=include_citations,
num_chunks=num_chunks or persona.num_chunks,
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,

View File

@@ -1,5 +1,3 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
@@ -12,8 +10,6 @@ from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import IndexingStatus
from onyx.server.documents.models import InputType
"""
@@ -175,32 +171,3 @@ class DATestSettings(BaseModel):
gpu_enabled: bool | None = None
product_gating: DATestGatingType = DATestGatingType.NONE
anonymous_user_enabled: bool | None = None
@dataclass
class DATestIndexAttempt:
id: int
status: IndexingStatus | None
new_docs_indexed: int | None
total_docs_indexed: int | None
docs_removed_from_index: int | None
error_msg: str | None
time_started: datetime | None
time_updated: datetime | None
@classmethod
def from_index_attempt_snapshot(
cls, index_attempt: IndexAttemptSnapshot
) -> "DATestIndexAttempt":
return cls(
id=index_attempt.id,
status=index_attempt.status,
new_docs_indexed=index_attempt.new_docs_indexed,
total_docs_indexed=index_attempt.total_docs_indexed,
docs_removed_from_index=index_attempt.docs_removed_from_index,
error_msg=index_attempt.error_msg,
time_started=datetime.fromisoformat(index_attempt.time_started)
if index_attempt.time_started
else None,
time_updated=datetime.fromisoformat(index_attempt.time_updated),
)

View File

@@ -1,90 +0,0 @@
from datetime import datetime
from onyx.db.models import IndexingStatus
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
def _verify_index_attempt_pagination(
cc_pair_id: int,
index_attempts: list[DATestIndexAttempt],
page_size: int = 5,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_attempts: list[int] = []
last_time_started = None # Track the last time_started seen
for i in range(0, len(index_attempts), page_size):
paginated_result = IndexAttemptManager.get_index_attempt_page(
cc_pair_id=cc_pair_id,
page=(i // page_size),
page_size=page_size,
user_performing_action=user_performing_action,
)
# Verify that the total items is equal to the length of the index attempts list
assert paginated_result.total_items == len(index_attempts)
# Verify that the number of items in the page is equal to the page size
assert len(paginated_result.items) == min(page_size, len(index_attempts) - i)
# Verify time ordering within the page (descending order)
for attempt in paginated_result.items:
if last_time_started is not None:
assert (
attempt.time_started <= last_time_started
), "Index attempts not in descending time order"
last_time_started = attempt.time_started
# Add the retrieved index attempts to the list of retrieved attempts
retrieved_attempts.extend([attempt.id for attempt in paginated_result.items])
# Create a set of all the expected index attempt IDs
all_expected_attempts = set(attempt.id for attempt in index_attempts)
# Create a set of all the retrieved index attempt IDs
all_retrieved_attempts = set(retrieved_attempts)
# Verify that the set of retrieved attempts is equal to the set of expected attempts
assert all_expected_attempts == all_retrieved_attempts
def test_index_attempt_pagination(reset: None) -> None:
# Create an admin user to perform actions
user_performing_action: DATestUser = UserManager.create(
name="admin_performing_action",
is_first_user=True,
)
# Create a CC pair to attach index attempts to
cc_pair = CCPairManager.create_from_scratch(
user_performing_action=user_performing_action,
)
# Create 300 successful index attempts
base_time = datetime.now()
all_attempts = IndexAttemptManager.create_test_index_attempts(
num_attempts=300,
cc_pair_id=cc_pair.id,
status=IndexingStatus.SUCCESS,
base_time=base_time,
)
# Verify basic pagination with different page sizes
print("Verifying basic pagination with page size 5")
_verify_index_attempt_pagination(
cc_pair_id=cc_pair.id,
index_attempts=all_attempts,
page_size=5,
user_performing_action=user_performing_action,
)
# Test with a larger page size
print("Verifying pagination with page size 100")
_verify_index_attempt_pagination(
cc_pair_id=cc_pair.id,
index_attempts=all_attempts,
page_size=100,
user_performing_action=user_performing_action,
)

View File

@@ -3,7 +3,6 @@ import uuid
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
@@ -21,15 +20,14 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None
return next((p for p in providers if p["id"] == provider_id), None)
def test_create_llm_provider_without_display_model_names(reset: None) -> None:
def test_create_llm_provider_without_display_model_names(
admin_user: DATestUser,
) -> None:
"""Test creating an LLM provider without specifying
display_model_names and verify it's null in response"""
# Create admin user
admin_user = UserManager.create(name="admin_user")
# Create LLM provider without model_names
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": str(uuid.uuid4()),
@@ -51,15 +49,12 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
assert provider_data["display_model_names"] is None
def test_update_llm_provider_model_names(reset: None) -> None:
def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
"""Test updating an LLM provider's model_names"""
# Create admin user
admin_user = UserManager.create(name="admin_user")
# First create provider without model_names
name = str(uuid.uuid4())
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": name,
@@ -95,14 +90,11 @@ def test_update_llm_provider_model_names(reset: None) -> None:
assert provider_data["model_names"] == _DEFAULT_MODELS
def test_delete_llm_provider(reset: None) -> None:
def test_delete_llm_provider(admin_user: DATestUser) -> None:
"""Test deleting an LLM provider"""
# Create admin user
admin_user = UserManager.create(name="admin_user")
# Create a provider
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"name": "test-provider-delete",

View File

@@ -22,7 +22,6 @@ const cspHeader = `
/** @type {import('next').NextConfig} */
const nextConfig = {
productionBrowserSourceMaps: false,
output: "standalone",
publicRuntimeConfig: {
version,
@@ -72,7 +71,7 @@ const nextConfig = {
// Sentry configuration for error monitoring:
// - Without SENTRY_AUTH_TOKEN and NEXT_PUBLIC_SENTRY_DSN: Sentry is completely disabled
// - With both configured: Capture errors and limited performance data
// - With both configured: Only unhandled errors are captured (no performance/session tracking)
// Determine if Sentry should be enabled
const sentryEnabled = Boolean(
@@ -88,11 +87,9 @@ const sentryWebpackPluginOptions = {
dryRun: !sentryEnabled, // Don't upload source maps when Sentry is disabled
sourceMaps: {
include: ["./.next"],
ignore: ["node_modules"],
validate: false,
urlPrefix: "~/_next",
stripPrefix: ["webpack://_N_E/"],
validate: true,
cleanArtifacts: true,
skip: !sentryEnabled,
},
};

1214
web/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -24,15 +24,13 @@
"@radix-ui/react-label": "^2.1.1",
"@radix-ui/react-popover": "^1.1.2",
"@radix-ui/react-radio-group": "^1.2.2",
"@radix-ui/react-scroll-area": "^1.2.2",
"@radix-ui/react-select": "^2.1.2",
"@radix-ui/react-separator": "^1.1.0",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-switch": "^1.1.1",
"@radix-ui/react-tabs": "^1.1.1",
"@radix-ui/react-tooltip": "^1.1.3",
"@sentry/nextjs": "^8.50.0",
"@sentry/tracing": "^7.120.3",
"@sentry/nextjs": "^8.34.0",
"@stripe/stripe-js": "^4.6.0",
"@types/js-cookie": "^3.0.3",
"@types/lodash": "^4.17.0",

View File

@@ -3,10 +3,12 @@ import * as Sentry from "@sentry/nextjs";
if (process.env.NEXT_PUBLIC_SENTRY_DSN) {
Sentry.init({
dsn: process.env.NEXT_PUBLIC_SENTRY_DSN,
// Capture unhandled exceptions and performance data
enableTracing: true,
// Only capture unhandled exceptions
enableTracing: false,
integrations: [],
tracesSampleRate: 0.1,
tracesSampleRate: 0,
replaysSessionSampleRate: 0,
replaysOnErrorSampleRate: 0,
autoSessionTracking: false,
});
}

View File

@@ -1,6 +1,6 @@
"use client";
import React from "react";
import React, { useCallback } from "react";
import { Option } from "@/components/Dropdown";
import { generateRandomIconShape } from "@/lib/assistantIconUtils";
import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types";
@@ -35,7 +35,7 @@ import {
import Link from "next/link";
import { useRouter } from "next/navigation";
import { useEffect, useMemo, useState } from "react";
import { FiInfo } from "react-icons/fi";
import { FiInfo, FiRefreshCcw, FiUsers } from "react-icons/fi";
import * as Yup from "yup";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
@@ -60,11 +60,10 @@ import { useAssistants } from "@/components/context/AssistantsContext";
import { debounce } from "lodash";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import StarterMessagesList from "./StarterMessageList";
import { Switch, SwitchField } from "@/components/ui/switch";
import { Switch } from "@/components/ui/switch";
import { generateIdenticon } from "@/components/assistants/AssistantIcon";
import { BackButton } from "@/components/BackButton";
import { Checkbox, CheckboxField } from "@/components/ui/checkbox";
import { Checkbox } from "@/components/ui/checkbox";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { MinimalUserSnapshot } from "@/lib/types";
import { useUserGroups } from "@/lib/hooks";
@@ -73,13 +72,11 @@ import {
Option as DropdownOption,
} from "@/components/Dropdown";
import { SourceChip } from "@/app/chat/input/ChatInputBar";
import { TagIcon, UserIcon, XIcon } from "lucide-react";
import { TagIcon, UserIcon } from "lucide-react";
import { LLMSelector } from "@/components/llm/LLMSelector";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
import { DeletePersonaButton } from "./[id]/DeletePersonaButton";
import Title from "@/components/ui/title";
function findSearchTool(tools: ToolSnapshot[]) {
return tools.find((tool) => tool.in_code_tool_id === "SearchTool");
@@ -131,8 +128,8 @@ export function AssistantEditor({
const router = useRouter();
const { popup, setPopup } = usePopup();
const { labels, refreshLabels, createLabel, updateLabel, deleteLabel } =
useLabels();
const { data, refreshLabels } = useLabels();
const labels = data || [];
const colorOptions = [
"#FF6FBF",
@@ -144,7 +141,11 @@ export function AssistantEditor({
"#6FFFFF",
];
const [showSearchTool, setShowSearchTool] = useState(false);
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const [hasEditedStarterMessage, setHasEditedStarterMessage] = useState(false);
const [showPersonaLabel, setShowPersonaLabel] = useState(!admin);
// state to persist across formik reformatting
const [defautIconColor, _setDeafultIconColor] = useState(
@@ -221,7 +222,6 @@ export function AssistantEditor({
const initialValues = {
name: existingPersona?.name ?? "",
description: existingPersona?.description ?? "",
datetime_aware: existingPrompt?.datetime_aware ?? false,
system_prompt: existingPrompt?.system_prompt ?? "",
task_prompt: existingPrompt?.task_prompt ?? "",
is_public: existingPersona?.is_public ?? defaultPublic,
@@ -330,10 +330,6 @@ export function AssistantEditor({
}));
};
if (!labels) {
return <></>;
}
return (
<div className="mx-auto max-w-4xl">
<style>
@@ -355,7 +351,7 @@ export function AssistantEditor({
entityName={labelToDelete.name}
onClose={() => setLabelToDelete(null)}
onSubmit={async () => {
const response = await deleteLabel(labelToDelete.id);
const response = await deletePersonaLabel(labelToDelete.id);
if (response?.ok) {
setPopup({
message: `Label deleted successfully`,
@@ -579,7 +575,7 @@ export function AssistantEditor({
return (
<Form className="w-full text-text-950 assistant-editor">
{/* Refresh starter messages when name or description changes */}
<p className="text-base font-normal text-2xl">
<p className="text-base font-normal !text-2xl">
{existingPersona ? (
<>
Edit assistant <b>{existingPersona.name}</b>
@@ -748,6 +744,97 @@ export function AssistantEditor({
className="[&_input]:placeholder:text-text-muted/50"
/>
<div className=" w-full max-w-4xl">
<Separator />
<div className="flex gap-x-2 items-center mt-4 ">
<div className="block font-medium text-sm">Labels</div>
</div>
<p
className="text-sm text-subtle"
style={{ color: "rgb(113, 114, 121)" }}
>
Select labels to categorize this assistant
</p>
<div className="mt-3">
<SearchMultiSelectDropdown
onCreateLabel={async (name: string) => {
await createPersonaLabel(name);
const currentLabels = await refreshLabels();
setTimeout(() => {
const newLabelId = currentLabels.find(
(l: { name: string }) => l.name === name
)?.id;
const updatedLabelIds = [
...values.label_ids,
newLabelId as number,
];
setFieldValue("label_ids", updatedLabelIds);
}, 300);
}}
options={Array.from(
new Set(labels.map((label) => label.name))
).map((name) => ({
name,
value: name,
}))}
onSelect={(selected) => {
const newLabelIds = [
...values.label_ids,
labels.find((l) => l.name === selected.value)
?.id as number,
];
setFieldValue("label_ids", newLabelIds);
}}
itemComponent={({ option }) => (
<div
className="flex items-center px-4 py-2.5 text-sm hover:bg-hover cursor-pointer"
onClick={() => {
const label = labels.find(
(l) => l.name === option.value
);
if (label) {
const isSelected = values.label_ids.includes(
label.id
);
const newLabelIds = isSelected
? values.label_ids.filter(
(id: number) => id !== label.id
)
: [...values.label_ids, label.id];
setFieldValue("label_ids", newLabelIds);
}
}}
>
<span className="text-sm font-medium leading-none">
{option.name}
</span>
</div>
)}
/>
<div className="mt-2 flex flex-wrap gap-2">
{values.label_ids.map((labelId: number) => {
const label = labels.find((l) => l.id === labelId);
return label ? (
<SourceChip
key={label.id}
onRemove={() => {
setFieldValue(
"label_ids",
values.label_ids.filter(
(id: number) => id !== label.id
)
);
}}
title={label.name}
icon={<TagIcon size={12} />}
/>
) : null;
})}
</div>
</div>
</div>
<Separator />
<TextFormField
@@ -785,9 +872,10 @@ export function AssistantEditor({
: ""
}`}
>
<SwitchField
<Switch
size="sm"
onCheckedChange={(checked) => {
setShowSearchTool(checked);
setFieldValue("num_chunks", null);
toggleToolInValues(searchTool.id);
}}
@@ -801,7 +889,8 @@ export function AssistantEditor({
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] text-sm rounded-lg p-1.5 text-white">
To use the Knowledge Action, you need to
have at least one Connector configured.
have at least one Connector-Credential
pair configured.
</p>
</TooltipContent>
)}
@@ -821,7 +910,7 @@ export function AssistantEditor({
)}
{ccPairs.length > 0 &&
searchTool &&
values.enabled_tools_map[searchTool.id] &&
showSearchTool &&
!(user?.role != "admin" && documentSets.length === 0) && (
<CollapsibleSection>
<div className="mt-2">
@@ -909,10 +998,14 @@ export function AssistantEditor({
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<CheckboxField
<Checkbox
size="sm"
id={`enabled_tools_map.${imageGenerationTool.id}`}
name={`enabled_tools_map.${imageGenerationTool.id}`}
checked={
values.enabled_tools_map[
imageGenerationTool.id
]
}
onCheckedChange={() => {
if (
currentLLMSupportsImageOutput &&
@@ -968,7 +1061,6 @@ export function AssistantEditor({
onCheckedChange={() => {
toggleToolInValues(internetSearchTool.id);
}}
name={`enabled_tools_map.${internetSearchTool.id}`}
/>
<div className="flex flex-col ml-2">
<span className="text-sm">
@@ -988,7 +1080,6 @@ export function AssistantEditor({
<React.Fragment key={tool.id}>
<div className="flex items-center content-start mb-2">
<Checkbox
size="sm"
id={`enabled_tools_map.${tool.id}`}
checked={values.enabled_tools_map[tool.id]}
onCheckedChange={() => {
@@ -1023,6 +1114,7 @@ export function AssistantEditor({
)
: null
}
userDefault={user?.preferences?.default_model || null}
requiresImageGeneration={
imageGenerationTool
? values.enabled_tools_map[imageGenerationTool.id]
@@ -1044,6 +1136,106 @@ export function AssistantEditor({
/>
</div>
{admin && labels && labels.length > 0 && (
<div className=" max-w-4xl">
<Separator />
<div className="flex gap-x-2 items-center ">
<div className="block font-medium text-sm">
Manage Labels
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
Manage existing labels or create new ones to group
similar assistants
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<SubLabel>Edit or delete existing labels</SubLabel>
<div className="grid grid-cols-1 gap-4">
{labels.map((label: PersonaLabel) => (
<div
key={label.id}
className="grid grid-cols-[1fr,2fr,auto] gap-4 items-end"
>
<TextFormField
fontSize="sm"
name={`editLabelName_${label.id}`}
label="Label Name"
value={
values.editLabelId === label.id
? values.editLabelName
: label.name
}
onChange={(e) => {
setFieldValue("editLabelId", label.id);
setFieldValue("editLabelName", e.target.value);
}}
/>
<div className="flex gap-2">
{values.editLabelId === label.id ? (
<>
<Button
onClick={async () => {
const updatedName =
values.editLabelName || label.name;
const response = await updatePersonaLabel(
label.id,
updatedName
);
if (response?.ok) {
setPopup({
message: `Label "${updatedName}" updated successfully`,
type: "success",
});
await refreshLabels();
setFieldValue("editLabelId", null);
setFieldValue("editLabelName", "");
setFieldValue("editLabelDescription", "");
} else {
setPopup({
message: `Failed to update label - ${await response.text()}`,
type: "error",
});
}
}}
>
Save
</Button>
<Button
variant="outline"
onClick={() => {
setFieldValue("editLabelId", null);
setFieldValue("editLabelName", "");
setFieldValue("editLabelDescription", "");
}}
>
Cancel
</Button>
</>
) : (
<>
<Button
variant="destructive"
onClick={async () => {
setLabelToDelete(label);
}}
>
Delete
</Button>
</>
)}
</div>
</div>
))}
</div>
</div>
)}
<Separator />
<AdvancedOptionsToggle
showAdvancedOptions={showAdvancedOptions}
@@ -1061,9 +1253,9 @@ export function AssistantEditor({
<div className="min-h-[100px]">
<div className="flex items-center mb-2">
<SwitchField
name="is_public"
<Switch
size="md"
checked={values.is_public}
onCheckedChange={(checked) => {
setFieldValue("is_public", checked);
if (checked) {
@@ -1205,124 +1397,19 @@ export function AssistantEditor({
autoStarterMessageEnabled={
autoStarterMessageEnabled
}
errors={errors}
isRefreshing={isRefreshing}
values={values.starter_messages}
arrayHelpers={arrayHelpers}
touchStarterMessages={() => {
setHasEditedStarterMessage(true);
}}
setFieldValue={setFieldValue}
/>
)}
/>
</div>
</div>
<div className=" w-full max-w-4xl">
<Separator />
<div className="flex gap-x-2 items-center mt-4 ">
<div className="block font-medium text-sm">Labels</div>
</div>
<p
className="text-sm text-subtle"
style={{ color: "rgb(113, 114, 121)" }}
>
Select labels to categorize this assistant
</p>
<div className="mt-3">
<SearchMultiSelectDropdown
onCreate={async (name: string) => {
await createLabel(name);
const currentLabels = await refreshLabels();
setTimeout(() => {
const newLabelId = currentLabels.find(
(l: { name: string }) => l.name === name
)?.id;
const updatedLabelIds = [
...values.label_ids,
newLabelId as number,
];
setFieldValue("label_ids", updatedLabelIds);
}, 300);
}}
options={Array.from(
new Set(labels.map((label) => label.name))
).map((name) => ({
name,
value: name,
}))}
onSelect={(selected) => {
const newLabelIds = [
...values.label_ids,
labels.find((l) => l.name === selected.value)
?.id as number,
];
setFieldValue("label_ids", newLabelIds);
}}
itemComponent={({ option }) => (
<div className="flex items-center justify-between px-4 py-3 text-sm hover:bg-hover cursor-pointer border-b border-border last:border-b-0">
<div
className="flex-grow"
onClick={() => {
const label = labels.find(
(l) => l.name === option.value
);
if (label) {
const isSelected = values.label_ids.includes(
label.id
);
const newLabelIds = isSelected
? values.label_ids.filter(
(id: number) => id !== label.id
)
: [...values.label_ids, label.id];
setFieldValue("label_ids", newLabelIds);
}
}}
>
<span className="font-normal leading-none">
{option.name}
</span>
</div>
{admin && (
<button
onClick={(e) => {
e.stopPropagation();
const label = labels.find(
(l) => l.name === option.value
);
if (label) {
deleteLabel(label.id);
}
}}
className="ml-2 p-1 hover:bg-background-hover rounded"
>
<TrashIcon size={16} />
</button>
)}
</div>
)}
/>
<div className="mt-2 flex flex-wrap gap-2">
{values.label_ids.map((labelId: number) => {
const label = labels.find((l) => l.id === labelId);
return label ? (
<SourceChip
key={label.id}
onRemove={() => {
setFieldValue(
"label_ids",
values.label_ids.filter(
(id: number) => id !== label.id
)
);
}}
title={label.name}
icon={<TagIcon size={12} />}
/>
) : null;
})}
</div>
</div>
</div>
<Separator />
<div className="flex flex-col gap-y-4">
@@ -1348,6 +1435,7 @@ export function AssistantEditor({
small
subtext="Documents prior to this date will be ignored."
label="[Optional] Knowledge Cutoff Date"
value={values.search_start_date}
name="search_start_date"
/>
@@ -1373,17 +1461,6 @@ export function AssistantEditor({
</div>
<Separator />
<BooleanFormField
small
removeIndent
alignTop
name="datetime_aware"
label="Date and Time Aware"
subtext='Toggle this option to let the assistant know the current date and time (formatted like: "Thursday Jan 1, 1970 00:01"). To inject it in a specific place in the prompt, use the pattern [[CURRENT_DATETIME]]'
/>
<Separator />
<TextFormField
maxWidth="max-w-4xl"
name="task_prompt"
@@ -1397,14 +1474,6 @@ export function AssistantEditor({
explanationLink="https://docs.onyx.app/guides/assistants"
className="[&_textarea]:placeholder:text-text-muted/50"
/>
<div className="flex justify-end">
{existingPersona && (
<DeletePersonaButton
personaId={existingPersona!.id}
redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
/>
)}
</div>
</>
)}

View File

@@ -1,182 +0,0 @@
"use client";
import React from "react";
import { Separator } from "@/components/ui/separator";
import { Button } from "@/components/ui/button";
import { SubLabel, TextFormField } from "@/components/admin/connectors/Field";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useLabels } from "@/lib/hooks";
import { PersonaLabel } from "./interfaces";
import { Form, Formik, FormikHelpers } from "formik";
import Title from "@/components/ui/title";
interface FormValues {
newLabelName: string;
editLabelId: number | null;
editLabelName: string;
}
export default function LabelManagement() {
const { labels, createLabel, updateLabel, deleteLabel } = useLabels();
const { setPopup, popup } = usePopup();
if (!labels) return null;
const handleSubmit = async (
values: FormValues,
{ setSubmitting, resetForm }: FormikHelpers<FormValues>
) => {
if (values.newLabelName.trim()) {
const response = await createLabel(values.newLabelName.trim());
if (response.ok) {
setPopup({
message: `Label "${values.newLabelName}" created successfully`,
type: "success",
});
resetForm();
} else {
const errorMsg = (await response.json()).detail;
setPopup({
message: `Failed to create label - ${errorMsg}`,
type: "error",
});
}
}
setSubmitting(false);
};
return (
<div>
{popup}
<div className="max-w-4xl">
<div className="flex gap-x-2 items-center">
<Title size="lg">Manage Labels</Title>
</div>
<Formik<FormValues>
initialValues={{
newLabelName: "",
editLabelId: null,
editLabelName: "",
}}
onSubmit={handleSubmit}
>
{({ values, setFieldValue, isSubmitting }) => (
<Form>
<div className="flex flex-col gap-4 mt-4 mb-6">
<div className="flex flex-col">
<Title className="text-lg">Create New Label</Title>
<SubLabel>
Labels are used to categorize personas. You can create a new
label by entering a name below.
</SubLabel>
</div>
<div className="max-w-3xl w-full justify-start flex gap-4 items-end">
<TextFormField
width="max-w-xs"
fontSize="sm"
name="newLabelName"
label="Label Name"
/>
<Button type="submit" disabled={isSubmitting}>
Create
</Button>
</div>
</div>
<div className="grid grid-cols-1 w-full gap-4">
<div className="flex flex-col">
<Title className="text-lg">Edit Labels</Title>
<SubLabel>
You can edit the name of a label by clicking on the label
name and entering a new name.
</SubLabel>
</div>
{labels.map((label: PersonaLabel) => (
<div key={label.id} className="flex w-full gap-4 items-end">
<TextFormField
fontSize="sm"
width="w-full max-w-xs"
name={`editLabelName_${label.id}`}
label="Label Name"
value={
values.editLabelId === label.id
? values.editLabelName
: label.name
}
onChange={(e) => {
setFieldValue("editLabelId", label.id);
setFieldValue("editLabelName", e.target.value);
}}
/>
<div className="flex gap-2">
{values.editLabelId === label.id ? (
<>
<Button
onClick={async () => {
const updatedName =
values.editLabelName || label.name;
const response = await updateLabel(
label.id,
updatedName
);
if (response.ok) {
setPopup({
message: `Label "${updatedName}" updated successfully`,
type: "success",
});
setFieldValue("editLabelId", null);
setFieldValue("editLabelName", "");
} else {
setPopup({
message: `Failed to update label - ${await response.text()}`,
type: "error",
});
}
}}
>
Save
</Button>
<Button
variant="outline"
onClick={() => {
setFieldValue("editLabelId", null);
setFieldValue("editLabelName", "");
}}
>
Cancel
</Button>
</>
) : (
<Button
variant="destructive"
onClick={async () => {
const response = await deleteLabel(label.id);
if (response.ok) {
setPopup({
message: `Label "${label.name}" deleted successfully`,
type: "success",
});
} else {
setPopup({
message: `Failed to delete label - ${await response.text()}`,
type: "error",
});
}
}}
>
Delete
</Button>
)}
</div>
</div>
))}
</div>
</Form>
)}
</Formik>
</div>
</div>
);
}

View File

@@ -102,6 +102,12 @@ export function PersonasTable() {
<div>
{popup}
<Text className="my-2">
Assistants will be displayed as options on the Chat / Search interfaces
in the order they are displayed below. Assistants marked as hidden will
not be displayed. Editable assistants are shown at the top.
</Text>
<DraggableTable
headers={["Name", "Description", "Type", "Is Visible", "Delete"]}
isAdmin={isAdmin}

View File

@@ -18,20 +18,25 @@ export default function StarterMessagesList({
values,
arrayHelpers,
isRefreshing,
touchStarterMessages,
debouncedRefreshPrompts,
autoStarterMessageEnabled,
errors,
setFieldValue,
}: {
values: StarterMessage[];
arrayHelpers: ArrayHelpers;
isRefreshing: boolean;
touchStarterMessages: () => void;
debouncedRefreshPrompts: () => void;
autoStarterMessageEnabled: boolean;
errors: any;
setFieldValue: any;
}) {
const [tooltipOpen, setTooltipOpen] = useState(false);
const handleInputChange = (index: number, value: string) => {
touchStarterMessages();
setFieldValue(`starter_messages.${index}.message`, value);
if (value && index === values.length - 1 && values.length < 4) {

View File

@@ -29,6 +29,12 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
defaultPublic={true}
redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
/>
<Title>Delete Assistant</Title>
<DeletePersonaButton
personaId={values.existingPersona!.id}
redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
/>
</CardSection>
</>
);

View File

@@ -6,7 +6,6 @@ interface PersonaUpsertRequest {
description: string;
system_prompt: string;
task_prompt: string;
datetime_aware: boolean;
document_set_ids: number[];
num_chunks: number | null;
include_citations: boolean;
@@ -37,7 +36,6 @@ export interface PersonaUpsertParameters {
system_prompt: string;
existing_prompt_id: number | null;
task_prompt: string;
datetime_aware: boolean;
document_set_ids: number[];
num_chunks: number | null;
include_citations: boolean;
@@ -107,7 +105,6 @@ function buildPersonaUpsertRequest(
is_public,
groups,
existing_prompt_id,
datetime_aware,
users,
tool_ids,
icon_color,
@@ -132,7 +129,6 @@ function buildPersonaUpsertRequest(
icon_shape,
remove_image,
search_start_date,
datetime_aware,
is_default_persona: creationRequest.is_default_persona ?? false,
recency_bias: "base_decay",
prompt_ids: existing_prompt_id ? [existing_prompt_id] : [],

View File

@@ -1,25 +1,32 @@
import { AssistantEditor } from "../AssistantEditor";
import { ErrorCallout } from "@/components/ErrorCallout";
import { RobotIcon } from "@/components/icons/icons";
import { BackButton } from "@/components/BackButton";
import CardSection from "@/components/admin/CardSection";
import { AdminPageTitle } from "@/components/admin/Title";
import { fetchAssistantEditorInfoSS } from "@/lib/assistants/fetchPersonaEditorInfoSS";
import { SuccessfulPersonaUpdateRedirectType } from "../enums";
export default async function Page() {
const [values, error] = await fetchAssistantEditorInfoSS();
let body;
if (!values) {
return (
body = (
<ErrorCallout errorTitle="Something went wrong :(" errorMsg={error} />
);
} else {
return (
<div className="w-full">
body = (
<CardSection className="!border-none !bg-transparent !ring-none">
<AssistantEditor
{...values}
admin
defaultPublic={true}
redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
/>
</div>
</CardSection>
);
}
return <div className="w-full">{body}</div>;
}

View File

@@ -1,4 +1,3 @@
"use client";
import { PersonasTable } from "./PersonaTable";
import { FiPlusSquare } from "react-icons/fi";
import Link from "next/link";
@@ -7,8 +6,6 @@ import Title from "@/components/ui/title";
import { Separator } from "@/components/ui/separator";
import { AssistantsIcon } from "@/components/icons/icons";
import { AdminPageTitle } from "@/components/admin/Title";
import LabelManagement from "./LabelManagement";
import { SubLabel } from "@/components/admin/connectors/Field";
export default async function Page() {
return (
@@ -46,12 +43,6 @@ export default async function Page() {
<Separator />
<Title>Existing Assistants</Title>
<SubLabel>
Assistants will be displayed as options on the Chat / Search
interfaces in the order they are displayed below. Assistants marked as
hidden will not be displayed. Editable assistants are shown at the
top.
</SubLabel>
<PersonasTable />
</div>
</div>

View File

@@ -11,13 +11,13 @@ import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { getErrorMsg } from "@/lib/fetchUtils";
import { ScoreSection } from "../ScoreEditor";
import { useRouter } from "next/navigation";
import { HorizontalFilters } from "@/components/search/filtering/Filters";
import { useFilters } from "@/lib/hooks";
import { buildFilters } from "@/lib/search/utils";
import { DocumentUpdatedAtBadge } from "@/components/search/DocumentUpdatedAtBadge";
import { DocumentSet } from "@/lib/types";
import { SourceIcon } from "@/components/SourceIcon";
import { Connector } from "@/lib/connectors/connectors";
import { HorizontalFilters } from "@/app/chat/shared_chat_search/Filters";
const DocumentDisplay = ({
document,
@@ -200,9 +200,6 @@ export function Explorer({
availableDocumentSets={documentSets}
existingSources={connectors.map((connector) => connector.source)}
availableTags={[]}
toggleFilters={() => {}}
filtersUntoggled={false}
tagsOnLeft={true}
/>
</div>
</div>

View File

@@ -59,7 +59,7 @@ function SummaryRow({
return (
<TableRow
onClick={onToggle}
className="border-border group hover:bg-background-settings-hover bg-background-sidebar py-4 rounded-sm !border cursor-pointer"
className="border-border bg-white py-4 rounded-sm !border cursor-pointer"
>
<TableCell>
<div className="text-xl flex items-center truncate ellipsis gap-x-2 font-semibold">
@@ -86,7 +86,7 @@ function SummaryRow({
<Tooltip>
<TooltipTrigger asChild>
<div className="flex items-center mt-1">
<div className="w-full bg-white rounded-full h-2 mr-2">
<div className="w-full bg-gray-200 rounded-full h-2 mr-2">
<div
className="bg-green-500 h-2 rounded-full"
style={{ width: `${activePercentage}%` }}

View File

@@ -1,21 +1,21 @@
"use client";
import useSWR from "swr";
import { useContext, useState } from "react";
import { useState } from "react";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { Button } from "@/components/ui/button";
import { NEXT_PUBLIC_WEB_DOMAIN } from "@/lib/constants";
import { ClipboardIcon } from "@/components/icons/icons";
import { Input } from "@/components/ui/input";
import { ThreeDotsLoader } from "@/components/Loading";
import { SettingsContext } from "@/components/settings/SettingsProvider";
export function AnonymousUserPath({
setPopup,
}: {
setPopup: (popup: PopupSpec) => void;
}) {
const settings = useContext(SettingsContext);
const fetcher = (url: string) => fetch(url).then((res) => res.json());
const [customPath, setCustomPath] = useState<string | null>(null);
const {
@@ -116,7 +116,7 @@ export function AnonymousUserPath({
<div className="flex flex-col gap-2 justify-center items-start">
<div className="w-full flex-grow flex items-center rounded-md shadow-sm">
<span className="inline-flex items-center rounded-l-md border border-r-0 border-gray-300 bg-gray-50 px-3 text-gray-500 sm:text-sm h-10">
{settings?.webDomain}/anonymous/
{NEXT_PUBLIC_WEB_DOMAIN}/anonymous/
</span>
<Input
type="text"
@@ -141,7 +141,7 @@ export function AnonymousUserPath({
className="h-10 px-4"
onClick={() => {
navigator.clipboard.writeText(
`${settings?.webDomain}/anonymous/${anonymousUserPath}`
`${NEXT_PUBLIC_WEB_DOMAIN}/anonymous/${anonymousUserPath}`
);
setPopup({
message: "Invite link copied!",

View File

@@ -62,5 +62,4 @@ export interface CombinedSettings {
customAnalyticsScript: string | null;
isMobile?: boolean;
webVersion: string | null;
webDomain: string | null;
}

View File

@@ -24,14 +24,6 @@ import debounce from "lodash/debounce";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import Link from "next/link";
import { Separator } from "@/components/ui/separator";
import { Checkbox } from "@/components/ui/checkbox";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { useAuthType } from "@/lib/hooks";
function parseJsonWithTrailingCommas(jsonString: string) {
// Regular expression to remove trailing commas before } or ]
@@ -59,11 +51,7 @@ function ToolForm({
}: {
existingTool?: ToolSnapshot;
values: ToolFormValues;
setFieldValue: <T = any>(
field: string,
value: T,
shouldValidate?: boolean
) => void;
setFieldValue: (field: string, value: string) => void;
isSubmitting: boolean;
definitionErrorState: [
string | null,
@@ -77,9 +65,6 @@ function ToolForm({
const [definitionError, setDefinitionError] = definitionErrorState;
const [methodSpecs, setMethodSpecs] = methodSpecsState;
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const authType = useAuthType();
const isOAuthEnabled = authType === "oidc" || authType === "google_oauth";
const debouncedValidateDefinition = useCallback(
(definition: string) => {
const validateDefinition = async () => {
@@ -233,38 +218,43 @@ function ToolForm({
</p>
<FieldArray
name="customHeaders"
render={(arrayHelpers) => (
<div>
<div className="space-y-2">
{values.customHeaders.map(
(header: { key: string; value: string }, index: number) => (
<div
key={index}
className="flex items-center space-x-2 bg-gray-50 p-3 rounded-lg shadow-sm"
>
<Field
name={`customHeaders.${index}.key`}
placeholder="Header Key"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Field
name={`customHeaders.${index}.value`}
placeholder="Header Value"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Button
type="button"
onClick={() => arrayHelpers.remove(index)}
variant="destructive"
size="sm"
className="transition-colors duration-200 hover:bg-red-600"
render={(arrayHelpers: ArrayHelpers) => (
<div className="space-y-4">
{values.customHeaders && values.customHeaders.length > 0 && (
<div className="space-y-3">
{values.customHeaders.map(
(
header: { key: string; value: string },
index: number
) => (
<div
key={index}
className="flex items-center space-x-2 bg-gray-50 p-3 rounded-lg shadow-sm"
>
Remove
</Button>
</div>
)
)}
</div>
<Field
name={`customHeaders.${index}.key`}
placeholder="Header Key"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Field
name={`customHeaders.${index}.value`}
placeholder="Header Value"
className="flex-1 p-2 border border-gray-300 rounded-md focus:ring-2 focus:ring-primary-500 focus:border-transparent"
/>
<Button
type="button"
onClick={() => arrayHelpers.remove(index)}
variant="destructive"
size="sm"
className="transition-colors duration-200 hover:bg-red-600"
>
Remove
</Button>
</div>
)
)}
</div>
)}
<Button
type="button"
@@ -278,75 +268,6 @@ function ToolForm({
</div>
)}
/>
<div className="mt-6">
<h3 className="text-xl font-bold mb-2 text-primary-600">
Authentication
</h3>
{isOAuthEnabled ? (
<div className="flex flex-col gap-y-2">
<div className="flex items-center space-x-2">
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<div
className={
values.customHeaders.some(
(header) =>
header.key.toLowerCase() === "authorization"
)
? "opacity-50"
: ""
}
>
<Checkbox
id="passthrough_auth"
size="sm"
checked={values.passthrough_auth}
disabled={values.customHeaders.some(
(header) =>
header.key.toLowerCase() === "authorization" &&
!values.passthrough_auth
)}
onCheckedChange={(checked) => {
setFieldValue("passthrough_auth", checked, true);
}}
/>
</div>
</TooltipTrigger>
{values.customHeaders.some(
(header) => header.key.toLowerCase() === "authorization"
) && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
Cannot enable OAuth passthrough when an
Authorization header is already set
</p>
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
<div className="flex flex-col">
<label
htmlFor="passthrough_auth"
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
>
Pass through user&apos;s OAuth token
</label>
<p className="text-xs text-subtle mt-1">
When enabled, the user&apos;s OAuth token will be passed
as the Authorization header for all API calls
</p>
</div>
</div>
</div>
) : (
<p className="text-sm text-subtle">
OAuth passthrough is only available when OIDC or OAuth
authentication is enabled
</p>
)}
</div>
</div>
)}
@@ -370,7 +291,6 @@ function ToolForm({
interface ToolFormValues {
definition: string;
customHeaders: { key: string; value: string }[];
passthrough_auth: boolean;
}
const ToolSchema = Yup.object().shape({
@@ -383,7 +303,6 @@ const ToolSchema = Yup.object().shape({
})
)
.default([]),
passthrough_auth: Yup.boolean().default(false),
});
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
@@ -407,27 +326,9 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
key: header.key,
value: header.value,
})) ?? [],
passthrough_auth: tool?.passthrough_auth ?? false,
}}
validationSchema={ToolSchema}
onSubmit={async (values: ToolFormValues) => {
const hasAuthHeader = values.customHeaders?.some(
(header) => header.key.toLowerCase() === "authorization"
);
if (hasAuthHeader && values.passthrough_auth) {
setPopup({
message:
"Cannot enable passthrough auth when Authorization " +
"headers are present. Please remove any Authorization " +
"headers first.",
type: "error",
});
console.log(
"Cannot enable passthrough auth when Authorization headers are present. Please remove any Authorization headers first."
);
return;
}
let definition: any;
try {
definition = parseJsonWithTrailingCommas(values.definition);
@@ -443,7 +344,6 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
description: description || "",
definition: definition,
custom_headers: values.customHeaders,
passthrough_auth: values.passthrough_auth,
};
let response;
if (tool) {

View File

@@ -5,7 +5,7 @@ import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import InvitedUserTable from "@/components/admin/users/InvitedUserTable";
import SignedUpUserTable from "@/components/admin/users/SignedUpUserTable";
import { SearchBar } from "@/components/search/SearchBar";
import { FiPlusSquare } from "react-icons/fi";
import { Modal } from "@/components/Modal";
import { ThreeDotsLoader } from "@/components/Loading";
@@ -18,7 +18,6 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import BulkAdd from "@/components/admin/users/BulkAdd";
import Text from "@/components/ui/text";
import { InvitedUserSnapshot } from "@/lib/types";
import { SearchBar } from "@/components/search/SearchBar";
const UsersTables = ({
q,

View File

@@ -117,6 +117,7 @@ export default function SidebarWrapper<T extends object>({
{" "}
<HistorySidebar
setShowAssistantsModal={setShowAssistantsModal}
assistants={assistants}
page={"chat"}
explicitlyUntoggle={explicitlyUntoggle}
ref={sidebarElementRef}
@@ -125,6 +126,7 @@ export default function SidebarWrapper<T extends object>({
existingChats={chatSessions}
currentChatSession={null}
folders={folders}
openedFolders={openedFolders}
/>
</div>
</div>

View File

@@ -1,12 +1,17 @@
import React, { useContext, useState, useRef, useLayoutEffect } from "react";
import React, { useState } from "react";
import { useRouter } from "next/navigation";
import {
FiMoreHorizontal,
FiShare2,
FiEye,
FiEyeOff,
FiTrash,
FiEdit,
FiHash,
FiBarChart,
FiLock,
FiUnlock,
FiSearch,
} from "react-icons/fi";
import { FaHashtag } from "react-icons/fa";
import {
@@ -21,38 +26,33 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import { useUser } from "@/components/user/UserProvider";
import { useAssistants } from "@/components/context/AssistantsContext";
import { checkUserOwnsAssistant } from "@/lib/assistants/utils";
import { toggleAssistantPinnedStatus } from "@/lib/assistants/pinnedAssistants";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { Button } from "@/components/ui/button";
import { PinnedIcon } from "@/components/icons/icons";
import {
deletePersona,
togglePersonaPublicStatus,
} from "@/app/admin/assistants/lib";
import { PencilIcon } from "lucide-react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { truncateString } from "@/lib/utils";
import { HammerIcon } from "lucide-react";
export const AssistantBadge = ({
text,
className,
maxLength,
}: {
text: string;
className?: string;
maxLength?: number;
}) => {
return (
<div
className={`h-4 px-1.5 py-1 text-[10px] flex-none bg-[#e6e3dd]/50 rounded-lg justify-center items-center gap-1 inline-flex ${className}`}
className={`h-4 px-1.5 py-1 text-[10px] bg-[#e6e3dd]/50 rounded-lg justify-center items-center gap-1 inline-flex ${className}`}
>
<div className="text-[#4a4a4a] font-normal leading-[8px]">
{maxLength ? truncateString(text, maxLength) : text}
</div>
<div className="text-[#4a4a4a] font-normal leading-[8px]">{text}</div>
</div>
);
};
@@ -62,9 +62,9 @@ const AssistantCard: React.FC<{
pinned: boolean;
closeModal: () => void;
}> = ({ persona, pinned, closeModal }) => {
const { user, toggleAssistantPinnedStatus } = useUser();
const { user, refreshUser } = useUser();
const router = useRouter();
const { refreshAssistants, pinnedAssistants } = useAssistants();
const { refreshAssistants } = useAssistants();
const isOwnedByUser = checkUserOwnsAssistant(user, persona);
@@ -72,8 +72,7 @@ const AssistantCard: React.FC<{
undefined
);
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const handleShare = () => setActivePopover("visibility");
const handleDelete = () => setActivePopover("delete");
const handleEdit = () => {
router.push(`/assistants/edit/${persona.id}`);
@@ -82,74 +81,33 @@ const AssistantCard: React.FC<{
const closePopover = () => setActivePopover(undefined);
const nameRef = useRef<HTMLHeadingElement>(null);
const hiddenNameRef = useRef<HTMLSpanElement>(null);
const [isNameTruncated, setIsNameTruncated] = useState(false);
useLayoutEffect(() => {
const checkTruncation = () => {
if (nameRef.current && hiddenNameRef.current) {
const visibleWidth = nameRef.current.offsetWidth;
const fullTextWidth = hiddenNameRef.current.offsetWidth;
setIsNameTruncated(fullTextWidth > visibleWidth);
}
};
checkTruncation();
window.addEventListener("resize", checkTruncation);
return () => window.removeEventListener("resize", checkTruncation);
}, [persona.name]);
return (
<div className="w-full p-2 overflow-visible pb-4 pt-3 bg-[#fefcf9] rounded shadow-[0px_0px_4px_0px_rgba(0,0,0,0.25)] flex flex-col">
<div className="w-full flex">
<div className="ml-2 flex-none mr-2 mt-1 w-10 h-10">
<div className="ml-2 mr-4 mt-1 w-8 h-8">
<AssistantIcon assistant={persona} size="large" />
</div>
<div className="flex-1 mt-1 flex flex-col">
<div className="flex justify-between items-start mb-1">
<div className="flex items-end gap-x-2 leading-none">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<h3
ref={nameRef}
className={` text-black line-clamp-1 break-all text-ellipsis leading-none font-semibold text-base lg-normal w-full overflow-hidden`}
>
{persona.name}
</h3>
</TooltipTrigger>
{isNameTruncated && (
<TooltipContent>{persona.name}</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
<span
ref={hiddenNameRef}
className="absolute left-0 top-0 invisible whitespace-nowrap"
aria-hidden="true"
>
<h3 className="text-black leading-none font-semibold text-base lg-normal">
{persona.name}
</span>
</h3>
{persona.labels && persona.labels.length > 0 && (
<>
{persona.labels.slice(0, 2).map((label, index) => (
<AssistantBadge
key={index}
text={label.name}
maxLength={10}
/>
{persona.labels.slice(0, 3).map((label, index) => (
<AssistantBadge key={index} text={label.name} />
))}
{persona.labels.length > 2 && (
{persona.labels.length > 3 && (
<AssistantBadge
text={`+${persona.labels.length - 2} more`}
text={`+${persona.labels.length - 3} more`}
/>
)}
</>
)}
</div>
{isOwnedByUser && (
<div className="flex ml-2 items-center gap-x-2">
<div className="flex items-center gap-x-2">
<Popover
open={activePopover !== undefined}
onOpenChange={(open) =>
@@ -183,29 +141,41 @@ const AssistantCard: React.FC<{
<FiEdit size={12} className="inline mr-2" />
Edit
</button>
{isPaidEnterpriseFeaturesEnabled && (
<button
onClick={
isOwnedByUser
? () => {
router.push(
`/assistants/stats/${persona.id}`
);
closePopover();
}
: undefined
}
className={`w-full text-left items-center px-2 py-1 rounded ${
isOwnedByUser
? "hover:bg-neutral-100"
: "opacity-50 cursor-not-allowed"
}`}
disabled={!isOwnedByUser}
>
<FiBarChart size={12} className="inline mr-2" />
Stats
</button>
)}
{/*
<button
onClick={isOwnedByUser ? handleShare : undefined}
className={`w-full text-left flex items-center px-2 py-1 rounded ${
isOwnedByUser
? "hover:bg-neutral-100"
: "opacity-50 cursor-not-allowed"
}`}
disabled={!isOwnedByUser}
>
<FiShare2 size={12} className="inline mr-2" />
Share
</button> */}
<button
onClick={
isOwnedByUser
? () => {
router.push(
`/assistants/stats/${persona.id}`
);
closePopover();
}
: undefined
}
className={`w-full text-left items-center px-2 py-1 rounded ${
isOwnedByUser
? "hover:bg-neutral-100"
: "opacity-50 cursor-not-allowed"
}`}
disabled={!isOwnedByUser}
>
<FiBarChart size={12} className="inline mr-2" />
Stats
</button>
<button
onClick={isOwnedByUser ? handleDelete : undefined}
className={`w-full text-left items-center px-2 py-1 rounded ${
@@ -251,33 +221,33 @@ const AssistantCard: React.FC<{
)}
</div>
<p className="text-black font-[350] mt-0 text-sm line-clamp-2 h-[2.7em]">
<p className="text-black font-[350] mt-0 text-sm mb-1 line-clamp-2 h-[2.7em]">
{persona.description || "\u00A0"}
</p>
<div className="flex flex-col ">
<div className="my-1.5">
<p className="flex items-center text-black text-xs opacity-50">
{persona.owner?.email || persona.builtin_persona ? (
<>
<span className="truncate">
By {persona.owner?.email || "Onyx"}
</span>
{/* <div className="mb-1 mt-1">
<div className="flex items-center">
</div>
</div> */}
<span className="mx-2"></span>
<div className="my-1">
<span className="flex items-center text-black text-xs opacity-50">
{(persona.owner?.email || persona.builtin_persona) && "By "}
{persona.owner?.email || (persona.builtin_persona && "Onyx")}
{(persona.owner?.email || persona.builtin_persona) && (
<span className="mx-2"></span>
)}
{persona.tools.length > 0 ? (
<>
{persona.tools.length}
{" Action"}
{persona.tools.length !== 1 ? "s" : ""}
</>
) : null}
<span className="flex-none truncate">
{persona.tools.length > 0 ? (
<>
{persona.tools.length}
{" Action"}
{persona.tools.length !== 1 ? "s" : ""}
</>
) : (
"No Actions"
)}
</span>
) : (
"No Actions"
)}
<span className="mx-2"></span>
{persona.is_public ? (
<>
@@ -290,7 +260,17 @@ const AssistantCard: React.FC<{
Private
</>
)}
</p>
</span>
</div>
<div className="mb-1 flex flex-wrap">
{persona.document_sets.slice(0, 5).map((set, index) => (
<AssistantBadge
className="!text-base"
key={index}
text={set.name}
/>
))}
</div>
</div>
<div className="flex gap-2">
@@ -304,7 +284,7 @@ const AssistantCard: React.FC<{
}}
className="hover:bg-neutral-100 hover:text-text px-2 py-1 gap-x-1 rounded border border-black flex items-center"
>
<PencilIcon size={12} className="flex-none" />
<FaHashtag size={12} className="flex-none" />
<span className="text-xs">Start Chat</span>
</button>
</TooltipTrigger>
@@ -316,25 +296,20 @@ const AssistantCard: React.FC<{
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div
<button
onClick={async () => {
await toggleAssistantPinnedStatus(
pinnedAssistants.map((a) => a.id),
user?.preferences.pinned_assistants || [],
persona.id,
!pinned
);
await refreshUser();
}}
className="hover:bg-neutral-100 px-2 group cursor-pointer py-1 gap-x-1 relative rounded border border-black flex items-center w-[65px]"
className="hover:bg-neutral-100 px-2 py-1 gap-x-1 rounded border border-black flex items-center w-[65px]"
>
<PinnedIcon size={12} />
{!pinned ? (
<p className="absolute w-full left-0 group-hover:text-black w-full text-center transform text-xs">
Pin
</p>
) : (
<p className="text-xs group-hover:text-black">Unpin</p>
)}
</div>
<p className="text-xs">{pinned ? "Unpin" : "Pin"}</p>
</button>
</TooltipTrigger>
<TooltipContent>
{pinned ? "Remove from" : "Add to"} your pinned list

View File

@@ -1,14 +1,16 @@
"use client";
import React, { useMemo, useState } from "react";
import React, { useMemo, useState, useEffect } from "react";
import { Persona } from "@/app/admin/assistants/interfaces";
import { useRouter } from "next/navigation";
import { Modal } from "@/components/Modal";
import AssistantCard from "./AssistantCard";
import { useAssistants } from "@/components/context/AssistantsContext";
import { useUser } from "@/components/user/UserProvider";
import { FilterIcon } from "lucide-react";
import { checkUserOwnsAssistant } from "@/lib/assistants/checkOwnership";
import { useUser } from "@/components/user/UserProvider";
import { Button } from "@/components/ui/button";
import { useLabels } from "@/lib/hooks";
export const AssistantBadgeSelector = ({
text,
@@ -25,7 +27,7 @@ export const AssistantBadgeSelector = ({
selected
? "bg-neutral-900 text-white"
: "bg-transparent text-neutral-900"
} w-12 h-5 text-center px-1 py-0.5 rounded-lg cursor-pointer text-[12px] font-normal leading-[10px] border border-black justify-center items-center gap-1 inline-flex`}
} h-5 px-1 py-0.5 rounded-lg cursor-pointer text-[12px] font-normal leading-[10px] border border-black justify-center items-center gap-1 inline-flex`}
onClick={toggleFilter}
>
{text}
@@ -37,7 +39,6 @@ export enum AssistantFilter {
Pinned = "Pinned",
Public = "Public",
Private = "Private",
Mine = "Mine",
}
const useAssistantFilter = () => {
@@ -47,7 +48,6 @@ const useAssistantFilter = () => {
[AssistantFilter.Pinned]: false,
[AssistantFilter.Public]: false,
[AssistantFilter.Private]: false,
[AssistantFilter.Mine]: false,
});
const toggleAssistantFilter = (filter: AssistantFilter) => {
@@ -65,8 +65,11 @@ export default function AssistantModal({
}: {
hideModal: () => void;
}) {
const { assistants, pinnedAssistants } = useAssistants();
const { assistantFilters, toggleAssistantFilter } = useAssistantFilter();
const [showAllFeaturedAssistants, setShowAllFeaturedAssistants] =
useState(false);
const { assistants, visibleAssistants, pinnedAssistants } = useAssistants();
const { assistantFilters, toggleAssistantFilter, setAssistantFilters } =
useAssistantFilter();
const router = useRouter();
const { user } = useUser();
const [searchQuery, setSearchQuery] = useState("");
@@ -86,21 +89,16 @@ export default function AssistantModal({
!assistantFilters[AssistantFilter.Private] || !assistant.is_public;
const pinnedFilter =
!assistantFilters[AssistantFilter.Pinned] ||
(user?.preferences?.pinned_assistants?.includes(assistant.id) ?? false);
const mineFilter =
!assistantFilters[AssistantFilter.Mine] ||
assistants.map((a: Persona) => checkUserOwnsAssistant(user, a));
pinnedAssistants.map((a: Persona) => a.id).includes(assistant.id);
return (
(nameMatches || labelMatches) &&
publicFilter &&
privateFilter &&
pinnedFilter &&
mineFilter
pinnedFilter
);
});
}, [assistants, searchQuery, assistantFilters]);
}, [assistants, searchQuery, assistantFilters, pinnedAssistants]);
const featuredAssistants = [
...memoizedCurrentlyVisibleAssistants.filter(
@@ -124,10 +122,10 @@ export default function AssistantModal({
heightOverride={`${height}px`}
onOutsideClick={hideModal}
removeBottomPadding
className={`max-w-4xl max-h-[90vh] ${height} w-[95%] overflow-hidden`}
className={`max-w-4xl ${height} w-[95%] overflow-hidden`}
>
<div className="flex flex-col h-full">
<div className="flex bg-background flex-col sticky top-0 z-10">
<div className="flex flex-col sticky top-0 z-10">
<div className="flex px-2 justify-between items-center gap-x-2 mb-0">
<div className="h-12 w-full rounded-lg flex-col justify-center items-start gap-2.5 inline-flex">
<div className="h-12 rounded-md w-full shadow-[0px_0px_2px_0px_rgba(0,0,0,0.25)] border border-[#dcdad4] flex items-center px-3">
@@ -166,18 +164,16 @@ export default function AssistantModal({
</div>
</button>
</div>
<div className="px-2 flex py-4 items-center gap-x-2 flex-wrap">
<FilterIcon size={16} />
<div className="px-2 flex py-2 items-center gap-x-2 mb-2 flex-wrap">
<AssistantBadgeSelector
text="Pinned"
selected={assistantFilters[AssistantFilter.Pinned]}
toggleFilter={() => toggleAssistantFilter(AssistantFilter.Pinned)}
/>
<AssistantBadgeSelector
text="Mine"
selected={assistantFilters[AssistantFilter.Mine]}
toggleFilter={() => toggleAssistantFilter(AssistantFilter.Mine)}
text="Public"
selected={assistantFilters[AssistantFilter.Public]}
toggleFilter={() => toggleAssistantFilter(AssistantFilter.Public)}
/>
<AssistantBadgeSelector
text="Private"
@@ -186,11 +182,6 @@ export default function AssistantModal({
toggleAssistantFilter(AssistantFilter.Private)
}
/>
<AssistantBadgeSelector
text="Public"
selected={assistantFilters[AssistantFilter.Public]}
toggleFilter={() => toggleAssistantFilter(AssistantFilter.Public)}
/>
</div>
<div className="w-full border-t border-neutral-200" />
</div>
@@ -205,9 +196,7 @@ export default function AssistantModal({
featuredAssistants.map((assistant, index) => (
<div key={index}>
<AssistantCard
pinned={pinnedAssistants
.map((a) => a.id)
.includes(assistant.id)}
pinned={pinnedAssistants.includes(assistant)}
persona={assistant}
closeModal={hideModal}
/>
@@ -232,11 +221,7 @@ export default function AssistantModal({
.map((assistant, index) => (
<div key={index}>
<AssistantCard
pinned={
user?.preferences?.pinned_assistants?.includes(
assistant.id
) ?? false
}
pinned={pinnedAssistants.includes(assistant)}
persona={assistant}
closeModal={hideModal}
/>

View File

@@ -60,6 +60,7 @@ export function ChatBanner() {
`}
onMouseLeave={handleMouseLeave}
aria-expanded={isExpanded}
role="region"
>
<div className="text-emphasis text-sm w-full">
{/* Padding for consistent spacing */}

View File

@@ -25,6 +25,7 @@ import { HealthCheckBanner } from "@/components/health/healthcheck";
import {
buildChatUrl,
buildLatestMessageChain,
checkAnyAssistantHasSearch,
createChatSession,
deleteAllChatSessions,
getCitedDocumentsFromMessage,
@@ -304,7 +305,12 @@ export function ChatPage({
const [presentingDocument, setPresentingDocument] =
useState<OnyxDocument | null>(null);
const { recentAssistants, refreshRecentAssistants } = useAssistants();
const {
visibleAssistants: assistants,
recentAssistants,
assistants: allAssistants,
refreshRecentAssistants,
} = useAssistants();
const liveAssistant: Persona | undefined =
alternativeAssistant ||
@@ -1432,10 +1438,10 @@ export function ChatPage({
}
}
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
parentMessage =
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
const updateFn = (messages: Message[]) => {
const replacementsMap = regenerationRequest
@@ -1884,7 +1890,6 @@ export function ChatPage({
const [showDeleteAllModal, setShowDeleteAllModal] = useState(false);
const currentPersona = alternativeAssistant || liveAssistant;
useEffect(() => {
const handleSlackChatRedirect = async () => {
if (!slackChatId) return;
@@ -2059,9 +2064,9 @@ export function ChatPage({
{retrievalEnabled && documentSidebarToggled && settings?.isMobile && (
<div className="md:hidden">
<Modal
hideDividerForTitle
onOutsideClick={() => setDocumentSidebarToggled(false)}
title="Sources"
noPadding
noScroll
>
<DocumentResults
setPresentingDocument={setPresentingDocument}
@@ -2078,7 +2083,6 @@ export function ChatPage({
maxTokens={maxTokens}
initialWidth={400}
isOpen={true}
removeHeader
/>
</Modal>
</div>
@@ -2156,16 +2160,20 @@ export function ChatPage({
<div className="w-full relative">
<HistorySidebar
setShowAssistantsModal={setShowAssistantsModal}
assistants={assistants}
explicitlyUntoggle={explicitlyUntoggle}
stopGenerating={stopGenerating}
reset={() => setMessage("")}
page="chat"
ref={innerSidebarElementRef}
toggleSidebar={toggleSidebar}
toggled={toggledSidebar}
backgroundToggled={toggledSidebar || showHistorySidebar}
currentAssistantId={liveAssistant?.id}
existingChats={chatSessions}
currentChatSession={selectedChatSession}
folders={folders}
openedFolders={openedFolders}
removeToggle={removeToggle}
showShareModal={showShareModal}
showDeleteAllModal={() => setShowDeleteAllModal(true)}
@@ -2183,11 +2191,7 @@ export function ChatPage({
bg-opacity-80
duration-300
ease-in-out
${
documentSidebarToggled &&
!settings?.isMobile &&
"opacity-100 w-[350px]"
}`}
${documentSidebarToggled && "opacity-100 w-[350px]"}`}
></div>
</div>
</div>
@@ -2208,11 +2212,7 @@ export function ChatPage({
duration-300
ease-in-out
h-full
${
documentSidebarToggled && !settings?.isMobile
? "w-[400px]"
: "w-[0px]"
}
${documentSidebarToggled ? "w-[400px]" : "w-[0px]"}
`}
>
<DocumentResults
@@ -2229,13 +2229,12 @@ export function ChatPage({
selectedDocumentTokens={selectedDocumentTokens}
maxTokens={maxTokens}
initialWidth={400}
isOpen={documentSidebarToggled && !settings?.isMobile}
isOpen={documentSidebarToggled}
/>
</div>
<BlurBackground
visible={!untoggled && (showHistorySidebar || toggledSidebar)}
onClick={() => toggleSidebar()}
/>
<div
@@ -2254,9 +2253,7 @@ export function ChatPage({
? setSharingModalVisible
: undefined
}
documentSidebarToggled={
documentSidebarToggled && !settings?.isMobile
}
documentSidebarToggled={documentSidebarToggled}
toggleSidebar={toggleSidebar}
currentChatSession={selectedChatSession}
hideUserDropdown={user?.is_anonymous_user}
@@ -2322,7 +2319,7 @@ export function ChatPage({
currentSessionChatState == "input" &&
!loadingError &&
!submittedMessage && (
<div className="h-full w-[95%] mx-auto flex flex-col justify-center items-center">
<div className="h-full w-[95%] mx-auto flex flex-col justify-center items-center">
<ChatIntro selectedPersona={liveAssistant} />
<StarterMessages
@@ -2343,10 +2340,9 @@ export function ChatPage({
(settings?.enterpriseSettings
?.two_lines_for_chat_header
? "pt-20 "
: "pt-8 ")
: "pt-8") +
(hasPerformedInitialScroll ? "" : "invisible")
}
// NOTE: temporarily removing this to fix the scroll bug
// (hasPerformedInitialScroll ? "" : "invisible")
>
{(messageHistory.length < BUFFER_COUNT
? messageHistory
@@ -2477,6 +2473,12 @@ export function ChatPage({
setPresentingDocument
}
index={i}
selectedMessageForDocDisplay={
selectedMessageForDocDisplay
}
documentSelectionToggled={
documentSidebarToggled
}
continueGenerating={
i == messageHistory.length - 1 &&
currentCanContinue()
@@ -2596,6 +2598,19 @@ export function ChatPage({
}
: undefined
}
handleShowRetrieved={(messageNumber) => {
if (isShowingRetrieved) {
setSelectedMessageForDocDisplay(null);
} else {
if (messageNumber !== null) {
setSelectedMessageForDocDisplay(
messageNumber
);
} else {
setSelectedMessageForDocDisplay(-1);
}
}
}}
handleForceSearch={() => {
if (
previousMessage &&
@@ -2802,11 +2817,7 @@ export function ChatPage({
duration-300
ease-in-out
h-full
${
documentSidebarToggled && !settings?.isMobile
? "w-[350px]"
: "w-[0px]"
}
${documentSidebarToggled ? "w-[350px]" : "w-[0px]"}
`}
></div>
</div>

View File

@@ -4,7 +4,10 @@ import {
LlmOverride,
useLlmOverride,
} from "@/lib/hooks";
import { StringOrNumberOption } from "@/components/Dropdown";
import {
DefaultDropdownElement,
StringOrNumberOption,
} from "@/components/Dropdown";
import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
@@ -12,7 +15,7 @@ import { useState } from "react";
import { Hoverable } from "@/components/Hoverable";
import { Popover } from "@/components/popover/Popover";
import { IconType } from "react-icons";
import { FiRefreshCw, FiCheck } from "react-icons/fi";
import { FiRefreshCw } from "react-icons/fi";
export function RegenerateDropdown({
options,
@@ -40,33 +43,45 @@ export function RegenerateDropdown({
};
const Dropdown = (
<div className="overflow-y-auto py-2 min-w-fit bg-white dark:bg-gray-800 rounded-md shadow-lg">
<div className="mb-1 flex items-center justify-between px-4 pt-2">
<span className="text-sm text-text-500 dark:text-text-400">
Regenerate with
</span>
</div>
{options.map((option) => (
<div
key={option.value}
role="menuitem"
className={`flex items-center m-1.5 p-1.5 text-sm cursor-pointer focus-visible:outline-0 group relative hover:bg-gray-100 dark:hover:bg-gray-700 rounded-md my-0 px-3 mx-2 gap-2.5 py-3 !pr-3 ${
option.value === selected ? "bg-gray-100 dark:bg-gray-700" : ""
}`}
onClick={() => onSelect(option.value)}
>
<div className="flex grow items-center justify-between gap-2">
<div>
<div className="flex items-center gap-3">
<div>{getDisplayNameForModel(option.name)}</div>
</div>
</div>
</div>
{option.value === selected && (
<FiCheck className="text-blue-500 dark:text-blue-400" />
)}
</div>
))}
<div
className={`
border
border
rounded-lg
flex
flex-col
mx-2
bg-background
${maxHeight || "max-h-72"}
overflow-y-auto
overscroll-contain relative`}
>
<p
className="
sticky
top-0
flex
bg-background
font-medium
px-2
text-sm
py-1.5
"
>
Regenerate with
</p>
{options.map((option, ind) => {
const isSelected = option.value === selected;
return (
<DefaultDropdownElement
key={option.value}
name={getDisplayNameForModel(option.name)}
description={option.description}
onSelect={() => onSelect(option.value)}
isSelected={isSelected}
/>
);
})}
</div>
);

View File

@@ -79,9 +79,13 @@ export function ChatDocumentDisplay({
document.updated_at || Object.keys(document.metadata).length > 0;
return (
<div className="desktop:max-w-[400px] opacity-100 w-full">
<div
className={`desktop:max-w-[400px] opacity-100 ${
modal ? "w-[90vw]" : "w-full"
}`}
>
<div
className={`flex relative flex-col px-3 py-2.5 gap-0.5 rounded-xl my-1 ${
className={`flex relative flex-col px-3 py-2.5 gap-0.5 rounded-xl mx-2 my-1 ${
isSelected ? "bg-[#ebe7de]" : "bg- hover:bg-[#ebe7de]/80"
}`}
>

View File

@@ -26,7 +26,6 @@ interface DocumentResultsProps {
isSharedChat?: boolean;
modal: boolean;
setPresentingDocument: Dispatch<SetStateAction<OnyxDocument | null>>;
removeHeader?: boolean;
}
export const DocumentResults = forwardRef<HTMLDivElement, DocumentResultsProps>(
@@ -44,10 +43,10 @@ export const DocumentResults = forwardRef<HTMLDivElement, DocumentResultsProps>(
isSharedChat,
isOpen,
setPresentingDocument,
removeHeader,
},
ref: ForwardedRef<HTMLDivElement>
) => {
const { popup, setPopup } = usePopup();
const [delayedSelectedDocumentCount, setDelayedSelectedDocumentCount] =
useState(0);
@@ -99,29 +98,21 @@ export const DocumentResults = forwardRef<HTMLDivElement, DocumentResultsProps>(
}}
>
<div className="flex flex-col h-full">
{!removeHeader && (
<>
<div className="p-4 flex items-center justify-between gap-x-2">
<div className="flex items-center gap-x-2">
<h2 className="text-xl font-bold text-text-900">
Sources
</h2>
</div>
<button className="my-auto" onClick={closeSidebar}>
<XIcon size={16} />
</button>
</div>
<div className="border-b border-divider-history-sidebar-bar mx-3" />
</>
)}
<div className="overflow-y-auto h-fit mb-8 pb-8 sm:mx-0 flex-grow gap-y-0 default-scrollbar dark-scrollbar flex flex-col">
{popup}
<div className="p-4 flex items-center justify-between gap-x-2">
<div className="flex items-center gap-x-2">
{/* <SourcesIcon size={32} /> */}
<h2 className="text-xl font-bold text-text-900">Sources</h2>
</div>
<button className="my-auto" onClick={closeSidebar}>
<XIcon size={16} />
</button>
</div>
<div className="border-b border-divider-history-sidebar-bar mx-3" />
<div className="overflow-y-auto h-fit mb-8 pb-8 -mx-1 sm:mx-0 flex-grow gap-y-0 default-scrollbar dark-scrollbar flex flex-col">
{dedupedDocuments.length > 0 ? (
dedupedDocuments.map((document, ind) => (
<div
key={document.document_id}
className={`desktop:px-2 w-full`}
>
<div key={document.document_id} className="w-full">
<ChatDocumentDisplay
setPresentingDocument={setPresentingDocument}
closeSidebar={closeSidebar}

View File

@@ -34,7 +34,6 @@ interface FolderDropdownProps {
onDelete?: (folderId: number) => void;
onDrop?: (folderId: number, chatSessionId: string) => void;
children?: ReactNode;
index: number;
}
export const FolderDropdown = forwardRef<HTMLDivElement, FolderDropdownProps>(
@@ -47,7 +46,6 @@ export const FolderDropdown = forwardRef<HTMLDivElement, FolderDropdownProps>(
onEdit,
onDrop,
children,
index,
},
ref
) => {
@@ -157,123 +155,117 @@ export const FolderDropdown = forwardRef<HTMLDivElement, FolderDropdownProps>(
ref={setNodeRef}
style={style}
{...attributes}
className="overflow-visible mt-2 w-full"
className="overflow-visible w-full"
onDragOver={handleDragOver}
onDrop={handleDrop}
>
<div
className="sticky top-0 bg-background-sidebar z-10"
style={{ zIndex: 1000 - index }}
ref={ref}
className="flex overflow-visible items-center w-full text-[#6c6c6c] rounded-md p-1 relative"
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
>
<div
ref={ref}
className="flex overflow-visible items-center w-full text-text-darker rounded-md p-1 relative bg-background-sidebar sticky top-0"
style={{ zIndex: 10 - index }}
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
<button
className="flex overflow-hidden items-center flex-grow"
onClick={() => !isEditing && setIsOpen(!isOpen)}
{...(isEditing ? {} : listeners)}
>
<button
className="flex overflow-hidden items-center flex-grow"
onClick={() => !isEditing && setIsOpen(!isOpen)}
{...(isEditing ? {} : listeners)}
>
{isOpen ? (
<Caret size={16} className="mr-1" />
) : (
<Caret size={16} className="-rotate-90 mr-1" />
)}
{isEditing ? (
<div ref={editingRef} className="flex-grow z-[9999] relative">
<input
ref={inputRef}
type="text"
value={newFolderName}
onChange={(e) => setNewFolderName(e.target.value)}
className="text-sm font-medium bg-transparent outline-none w-full pb-1 border-b border-[#6c6c6c] transition-colors duration-200"
onKeyDown={(e) => {
if (e.key === "Enter") {
handleEdit();
}
}}
onClick={(e) => e.stopPropagation()}
/>
</div>
) : (
<div className="flex items-center">
<span className="text-sm font-[500]">
{folder.folder_name}
</span>
</div>
)}
</button>
{isHovered && !isEditing && folder.folder_id && (
<button
onClick={(e) => {
e.stopPropagation();
setIsEditing(true);
}}
className="ml-auto px-1"
>
<PencilIcon size={14} />
</button>
{isOpen ? (
<Caret size={16} className="mr-1" />
) : (
<Caret size={16} className="-rotate-90 mr-1" />
)}
{(isHovered || isDeletePopoverOpen) &&
!isEditing &&
folder.folder_id && (
<Popover
open={isDeletePopoverOpen}
onOpenChange={setIsDeletePopoverOpen}
content={
<button
onClick={(e) => {
e.stopPropagation();
handleDeleteClick();
}}
className="px-1"
>
<FiTrash2 size={14} />
</button>
}
popover={
<div className="p-3 w-64 border border-border rounded-lg bg-background z-50">
<p className="text-sm mb-3">
Are you sure you want to delete this folder?
</p>
<div className="flex justify-center gap-2">
<button
className="px-3 py-1 text-sm bg-gray-200 rounded"
onClick={handleCancelDelete}
>
Cancel
</button>
<button
className="px-3 py-1 text-sm bg-red-500 text-white rounded"
onClick={handleConfirmDelete}
>
Delete
</button>
</div>
</div>
}
requiresContentPadding
sideOffset={6}
{isEditing ? (
<div ref={editingRef} className="flex-grow z-[9999] relative">
<input
ref={inputRef}
type="text"
value={newFolderName}
onChange={(e) => setNewFolderName(e.target.value)}
className="text-sm font-medium bg-transparent outline-none w-full pb-1 border-b border-[#6c6c6c] transition-colors duration-200"
onKeyDown={(e) => {
if (e.key === "Enter") {
handleEdit();
}
}}
onClick={(e) => e.stopPropagation()}
/>
)}
{isEditing && (
<div className="flex -my-1 z-[9999]">
<button onClick={handleEdit} className="p-1">
<FiCheck size={14} />
</button>
<button onClick={() => setIsEditing(false)} className="p-1">
<FiX size={14} />
</button>
</div>
) : (
<div className="flex items-center">
<span className="text-sm font-medium">
{folder.folder_name}
</span>
</div>
)}
</div>
{isOpen && (
<div className="overflow-visible mr-3 ml-1 mt-1">{children}</div>
</button>
{isHovered && !isEditing && folder.folder_id && (
<button
onClick={(e) => {
e.stopPropagation();
setIsEditing(true);
}}
className="ml-auto px-1"
>
<PencilIcon size={14} />
</button>
)}
{(isHovered || isDeletePopoverOpen) &&
!isEditing &&
folder.folder_id && (
<Popover
open={isDeletePopoverOpen}
onOpenChange={setIsDeletePopoverOpen}
content={
<button
onClick={(e) => {
e.stopPropagation();
handleDeleteClick();
}}
className="px-1"
>
<FiTrash2 size={14} />
</button>
}
popover={
<div className="p-3 w-64 border border-border rounded-lg bg-background z-50">
<p className="text-sm mb-3">
Are you sure you want to delete this folder?
</p>
<div className="flex justify-center gap-2">
<button
className="px-3 py-1 text-sm bg-gray-200 rounded"
onClick={handleCancelDelete}
>
Cancel
</button>
<button
className="px-3 py-1 text-sm bg-red-500 text-white rounded"
onClick={handleConfirmDelete}
>
Delete
</button>
</div>
</div>
}
requiresContentPadding
sideOffset={6}
/>
)}
{isEditing && (
<div className="flex -my-1 z-[9999]">
<button onClick={handleEdit} className="p-1">
<FiCheck size={14} />
</button>
<button onClick={() => setIsEditing(false)} className="p-1">
<FiX size={14} />
</button>
</div>
)}
</div>
{isOpen && (
<div className="overflow-visible mr-3 ml-1 mt-1">{children}</div>
)}
</div>
);
}

View File

@@ -268,7 +268,7 @@ export default function InputPrompts() {
<Title>Prompt Shortcuts</Title>
<Text>
Manage and customize prompt shortcuts for your assistants. Use your
prompt shortcuts by starting a new message / in chat.
prompt shortcuts by starting a new message / in chat
</Text>
</div>
</div>

View File

@@ -328,7 +328,6 @@ export function ChatInputBar({
<div className="flex justify-center mx-auto">
<div
className="
max-w-full
w-[800px]
relative
desktop:px-4
@@ -506,10 +505,7 @@ export function ChatInputBar({
style={{ scrollbarWidth: "thin" }}
role="textarea"
aria-multiline
placeholder={`Message ${truncateString(
selectedAssistant.name,
70
)} assistant...`}
placeholder={`Message ${selectedAssistant.name} assistant...`}
value={message}
onKeyDown={(event) => {
if (
@@ -653,101 +649,92 @@ export function ChatInputBar({
</div>
)}
<div className="flex justify-between items-center overflow-hidden px-4 mb-2">
<div className="flex gap-x-1">
<ChatInputOption
flexPriority="stiff"
name="File"
Icon={FiPlusCircle}
onClick={() => {
const input = document.createElement("input");
input.type = "file";
input.multiple = true;
input.onchange = (event: any) => {
const files = Array.from(
event?.target?.files || []
) as File[];
if (files.length > 0) {
handleFileUpload(files);
}
};
input.click();
}}
tooltipContent={"Upload files"}
/>
<LLMPopover
llmProviders={llmProviders}
llmOverrideManager={llmOverrideManager}
requiresImageGeneration={false}
currentAssistant={selectedAssistant}
/>
{retrievalEnabled && (
<FilterPopup
availableSources={availableSources}
availableDocumentSets={availableDocumentSets}
availableTags={availableTags}
filterManager={filterManager}
trigger={
<ChatInputOption
flexPriority="stiff"
name="Filters"
Icon={FiFilter}
tooltipContent="Filter your search"
/>
<div className="flex items-center space-x-1 mr-12 px-4 pb-2">
<ChatInputOption
flexPriority="stiff"
name="File"
Icon={FiPlusCircle}
onClick={() => {
const input = document.createElement("input");
input.type = "file";
input.multiple = true;
input.onchange = (event: any) => {
const files = Array.from(
event?.target?.files || []
) as File[];
if (files.length > 0) {
handleFileUpload(files);
}
/>
)}
</div>
<div className="flex my-auto">
};
input.click();
}}
tooltipContent={"Upload files"}
/>
<LLMPopover
llmProviders={llmProviders}
llmOverrideManager={llmOverrideManager}
requiresImageGeneration={false}
currentAssistant={selectedAssistant}
/>
{retrievalEnabled && (
<FilterPopup
availableSources={availableSources}
availableDocumentSets={availableDocumentSets}
availableTags={availableTags}
filterManager={filterManager}
trigger={
<ChatInputOption
flexPriority="stiff"
name="Filters"
Icon={FiFilter}
tooltipContent="Filter your search"
/>
}
/>
)}
</div>
<div className="absolute bottom-2.5 mobile:right-4 desktop:right-10">
{chatState == "streaming" ||
chatState == "toolBuilding" ||
chatState == "loading" ? (
<button
className={`cursor-pointer ${
chatState == "streaming" ||
chatState == "toolBuilding" ||
chatState == "loading"
? chatState != "streaming"
? "bg-background-400"
: "bg-background-800"
: ""
} h-[28px] w-[28px] rounded-full`}
chatState != "streaming"
? "bg-background-400"
: "bg-background-800"
} h-[28px] w-[28px] rounded-full`}
onClick={stopGenerating}
disabled={chatState != "streaming"}
>
<StopGeneratingIcon
size={10}
className={`text-emphasis m-auto text-white flex-none
}`}
/>
</button>
) : (
<button
className="cursor-pointer"
onClick={() => {
if (
chatState == "streaming" ||
chatState == "toolBuilding" ||
chatState == "loading"
) {
stopGenerating();
} else if (message) {
if (message) {
onSubmit();
}
}}
disabled={
(chatState == "streaming" ||
chatState == "toolBuilding" ||
chatState == "loading") &&
chatState != "streaming"
}
disabled={chatState != "input"}
>
{chatState == "streaming" ||
chatState == "toolBuilding" ||
chatState == "loading" ? (
<StopGeneratingIcon
size={10}
className="text-emphasis m-auto text-white flex-none"
/>
) : (
<SendIcon
size={26}
className={`text-emphasis text-white p-1 my-auto rounded-full ${
chatState == "input" && message
? "bg-submit-background"
: "bg-disabled-submit-background"
}`}
/>
)}
<SendIcon
size={26}
className={`text-emphasis text-white p-1 rounded-full ${
chatState == "input" && message
? "bg-submit-background"
: "bg-disabled-submit-background"
} `}
/>
</button>
</div>
)}
</div>
</div>
</div>

View File

@@ -15,7 +15,6 @@ interface ChatInputOptionProps {
tooltipContent?: React.ReactNode;
flexPriority?: "shrink" | "stiff" | "second";
toggle?: boolean;
minimize?: boolean;
}
export const ChatInputOption: React.FC<ChatInputOptionProps> = ({
@@ -27,10 +26,28 @@ export const ChatInputOption: React.FC<ChatInputOptionProps> = ({
tooltipContent,
toggle,
onClick,
minimize,
}) => {
const [isDropupVisible, setDropupVisible] = useState(false);
const [isTooltipVisible, setIsTooltipVisible] = useState(false);
const componentRef = useRef<HTMLButtonElement>(null);
useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (
componentRef.current &&
!componentRef.current.contains(event.target as Node)
) {
setIsTooltipVisible(false);
setDropupVisible(false);
}
};
document.addEventListener("mousedown", handleClickOutside);
return () => {
document.removeEventListener("mousedown", handleClickOutside);
};
}, []);
return (
<TooltipProvider>
<Tooltip>
@@ -69,7 +86,7 @@ export const ChatInputOption: React.FC<ChatInputOptionProps> = ({
size={size}
className="h-4 w-4 my-auto text-[#4a4a4a] group-hover:text-text flex-none"
/>
<div className={`flex items-center ${minimize && "mobile:hidden"}`}>
<div className="flex items-center">
{name && (
<span className="text-sm text-[#4a4a4a] group-hover:text-text break-all line-clamp-1">
{name}

View File

@@ -1,4 +1,4 @@
import React, { useState } from "react";
import React from "react";
import {
Popover,
PopoverContent,
@@ -32,7 +32,6 @@ export default function LLMPopover({
requiresImageGeneration,
currentAssistant,
}: LLMPopoverProps) {
const [isOpen, setIsOpen] = useState(false);
const { llmOverride, updateLLMOverride, globalDefault } = llmOverrideManager;
const currentLlm = llmOverride.modelName || globalDefault.modelName;
@@ -82,11 +81,10 @@ export default function LLMPopover({
: null;
return (
<Popover open={isOpen} onOpenChange={setIsOpen}>
<Popover>
<PopoverTrigger asChild>
<button className="focus:outline-none">
<ChatInputOption
minimize
toggle
flexPriority="stiff"
name={getDisplayNameForModel(
@@ -121,10 +119,7 @@ export default function LLMPopover({
? "bg-gray-100 text-text"
: "text-text-darker"
}`}
onClick={() => {
updateLLMOverride(destructureValue(value));
setIsOpen(false);
}}
onClick={() => updateLLMOverride(destructureValue(value))}
>
{icon({ size: 16, className: "flex-none my-auto " })}
<span className="line-clamp-1 ">
@@ -137,7 +132,14 @@ export default function LLMPopover({
(assistant)
</span>
);
} else if (globalDefault.modelName === name) {
return (
<span className="flex-none ml-auto text-xs">
(user default)
</span>
);
}
return null;
})()}
</button>
);

View File

@@ -94,7 +94,7 @@ export function SimplifiedChatInputBar({
rounded-lg
relative
text-text-chatbar
bg-white
bg-background-chatbar
[&:has(textarea:focus)]::ring-1
[&:has(textarea:focus)]::ring-black
"
@@ -146,7 +146,7 @@ export function SimplifiedChatInputBar({
resize-none
rounded-lg
border-0
bg-white
bg-background-chatbar
placeholder:text-text-chatbar-subtle
${
textAreaRef.current &&

View File

@@ -363,8 +363,8 @@ export function groupSessionsByDateRange(chatSessions: ChatSession[]) {
const groups: Record<string, ChatSession[]> = {
Today: [],
"Previous 7 Days": [],
"Previous 30 days": [],
"Over 30 days": [],
"Previous 30 Days": [],
"Over 30 days ago": [],
};
chatSessions.forEach((chatSession) => {
@@ -378,9 +378,9 @@ export function groupSessionsByDateRange(chatSessions: ChatSession[]) {
} else if (diffDays <= 7) {
groups["Previous 7 Days"].push(chatSession);
} else if (diffDays <= 30) {
groups["Previous 30 days"].push(chatSession);
groups["Previous 30 Days"].push(chatSession);
} else {
groups["Over 30 days"].push(chatSession);
groups["Over 30 days ago"].push(chatSession);
}
});
@@ -424,10 +424,9 @@ export function processRawChatHistory(
message: messageInfo.message,
type: messageInfo.message_type as "user" | "assistant",
files: messageInfo.files,
alternateAssistantID:
messageInfo.alternate_assistant_id !== null
? Number(messageInfo.alternate_assistant_id)
: null,
alternateAssistantID: messageInfo.alternate_assistant_id
? Number(messageInfo.alternate_assistant_id)
: null,
// only include these fields if this is an assistant message so that
// this is identical to what is computed at streaming time
...(messageInfo.message_type === "assistant"

View File

@@ -162,6 +162,7 @@ function FileDisplay({
export const AIMessage = ({
regenerate,
overriddenModel,
selectedMessageForDocDisplay,
continueGenerating,
shared,
isActive,
@@ -169,6 +170,7 @@ export const AIMessage = ({
alternativeAssistant,
docs,
messageId,
documentSelectionToggled,
content,
files,
selectedDocuments,
@@ -178,6 +180,7 @@ export const AIMessage = ({
isComplete,
hasDocs,
handleFeedback,
handleShowRetrieved,
handleSearchQueryEdit,
handleForceSearch,
retrievalDisabled,
@@ -189,6 +192,7 @@ export const AIMessage = ({
toggledDocumentSidebar,
}: {
index?: number;
selectedMessageForDocDisplay?: number | null;
shared?: boolean;
isActive?: boolean;
continueGenerating?: () => void;
@@ -201,6 +205,7 @@ export const AIMessage = ({
currentPersona: Persona;
messageId: number | null;
content: string | JSX.Element;
documentSelectionToggled?: boolean;
files?: FileDescriptor[];
query?: string;
citedDocuments?: [string, OnyxDocument][] | null;
@@ -209,6 +214,7 @@ export const AIMessage = ({
toggledDocumentSidebar?: boolean;
hasDocs?: boolean;
handleFeedback?: (feedbackType: FeedbackType) => void;
handleShowRetrieved?: (messageNumber: number | null) => void;
handleSearchQueryEdit?: (query: string) => void;
handleForceSearch?: () => void;
retrievalDisabled?: boolean;
@@ -596,7 +602,7 @@ export const AIMessage = ({
className={`
flex md:flex-row gap-x-0.5 mt-1
transition-transform duration-300 ease-in-out
transform opacity-100 "
transform opacity-100 translate-y-0"
`}
>
<TooltipGroup>
@@ -686,6 +692,10 @@ export const AIMessage = ({
settings?.isMobile) &&
"!opacity-100"
}
translate-y-2 ${
(isHovering || settings?.isMobile) && "!translate-y-0"
}
transition-transform duration-300 ease-in-out
flex md:flex-row gap-x-0.5 bg-background-125/40 -mx-1.5 p-1.5 rounded-lg
`}
>

Some files were not shown because too many files have changed in this diff Show More