Compare commits

..

2 Commits

Author SHA1 Message Date
pablodanswer
08b26c3227 update folder logic 2024-12-14 17:00:22 -08:00
pablodanswer
2cc72255d2 cloud settings -> billing 2024-12-14 17:00:22 -08:00
118 changed files with 797 additions and 2936 deletions

View File

@@ -8,29 +8,18 @@ on:
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
DOCKER_BUILDKIT: 1
BUILDKIT_PROGRESS: plain
jobs:
build-amd64:
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: System Info
run: |
df -h
free -h
docker system prune -af --volumes
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Login to Docker Hub
uses: docker/login-action@v3
@@ -38,80 +27,24 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and Push AMD64
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
outputs: type=registry
provenance: false
build-arm64:
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: System Info
run: |
df -h
free -h
docker system prune -af --volumes
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and Push ARM64
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
push: true
tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
outputs: type=registry
provenance: false
merge-and-scan:
needs: [build-amd64, build-arm64]
runs-on: ubuntu-latest
steps:
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create and Push Multi-arch Manifest
run: |
docker buildx create --use
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
ONYX_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
@@ -120,4 +53,3 @@ jobs:
with:
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout: "10m"

View File

@@ -15,12 +15,7 @@ jobs:
# See https://runs-on.com/runners/linux/
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -201,12 +196,7 @@ jobs:
needs: playwright-tests
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4

View File

@@ -20,7 +20,8 @@ env:
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on:
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4

View File

@@ -1,48 +1,38 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine.base import Connection
import os
import ssl
from typing import Literal
import asyncio
import logging
from logging.config import fileConfig
import logging
from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import MULTI_TENANT
from onyx.db.engine import build_connection_string
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine import get_all_tenant_ids
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Alembic Config object
config = context.config
# Interpret the config file for Python logging.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)
# Add your model's MetaData object here for 'autogenerate' support
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
if USE_IAM_AUTH:
if not os.path.exists(SSL_CERT_FILE):
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
# Set up logging
logger = logging.getLogger(__name__)
def include_object(
@@ -59,12 +49,20 @@ def include_object(
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
"""
Determines whether a database object should be included in migrations.
Excludes specified tables from migrations.
"""
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def get_schema_options() -> tuple[str, bool, bool]:
"""
Parses command-line options passed via '-x' in Alembic commands.
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
"""
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -92,12 +90,16 @@ def get_schema_options() -> tuple[str, bool, bool]:
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
"""
Executes migrations in the specified schema.
"""
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
# Set search_path to the target schema
connection.execute(text(f'SET search_path TO "{schema_name}"'))
context.configure(
@@ -115,25 +117,11 @@ def do_run_migrations(
context.run_migrations()
def provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
if USE_IAM_AUTH:
# Database connection settings
region = AWS_REGION
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
# Get IAM authentication token
token = get_iam_auth_token(host, port, user, region)
# For Alembic / SQLAlchemy in this context, set SSL and password
cparams["password"] = token
cparams["ssl"] = ssl_context
async def run_async_migrations() -> None:
"""
Determines whether to run migrations for a single schema or all schemas,
and executes migrations accordingly.
"""
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
@@ -141,16 +129,10 @@ async def run_async_migrations() -> None:
poolclass=pool.NullPool,
)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
def event_provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
if upgrade_all_tenants:
# Run migrations for all tenant schemas sequentially
tenant_schemas = get_all_tenant_ids()
for schema in tenant_schemas:
try:
logger.info(f"Migrating schema: {schema}")
@@ -180,20 +162,15 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode.
"""
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
# Run offline migrations for all tenant schemas
engine = create_async_engine(url)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
def event_provide_iam_token_for_alembic_offline(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
@@ -230,6 +207,9 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
"""
Runs migrations in 'online' mode using an asynchronous engine.
"""
asyncio.run(run_async_migrations())

View File

@@ -1,121 +0,0 @@
"""properly_cascade
Revision ID: 35e518e0ddf4
Revises: 91a0a4d62b14
Create Date: 2024-09-20 21:24:04.891018
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "35e518e0ddf4"
down_revision = "91a0a4d62b14"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Update chat_message foreign key constraint
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["id"],
ondelete="CASCADE",
)
# Update chat_message__search_doc foreign key constraints
op.drop_constraint(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"chat_message__search_doc_search_doc_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.create_foreign_key(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
"chat_message",
["chat_message_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"chat_message__search_doc_search_doc_id_fkey",
"chat_message__search_doc",
"search_doc",
["search_doc_id"],
["id"],
ondelete="CASCADE",
)
# Add CASCADE delete for tool_call foreign key
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.create_foreign_key(
"tool_call_message_id_fkey",
"tool_call",
"chat_message",
["message_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Revert chat_message foreign key constraint
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["id"],
)
# Revert chat_message__search_doc foreign key constraints
op.drop_constraint(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.drop_constraint(
"chat_message__search_doc_search_doc_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.create_foreign_key(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
"chat_message",
["chat_message_id"],
["id"],
)
op.create_foreign_key(
"chat_message__search_doc_search_doc_id_fkey",
"chat_message__search_doc",
"search_doc",
["search_doc_id"],
["id"],
)
# Revert tool_call foreign key constraint
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.create_foreign_key(
"tool_call_message_id_fkey",
"tool_call",
"chat_message",
["message_id"],
["id"],
)

View File

@@ -1,45 +0,0 @@
"""Milestone
Revision ID: 91a0a4d62b14
Revises: dab04867cd88
Create Date: 2024-12-13 19:03:30.947551
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "91a0a4d62b14"
down_revision = "dab04867cd88"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"milestone",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("tenant_id", sa.String(), nullable=True),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("event_type", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
)
def downgrade() -> None:
op.drop_table("milestone")

View File

@@ -1,87 +0,0 @@
"""delete workspace
Revision ID: c0aab6edb6dd
Revises: 35e518e0ddf4
Create Date: 2024-12-17 14:37:07.660631
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "c0aab6edb6dd"
down_revision = "35e518e0ddf4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = connector_specific_config - 'workspace'
WHERE source = 'SLACK'
"""
)
def downgrade() -> None:
import json
from sqlalchemy import text
from slack_sdk import WebClient
conn = op.get_bind()
# Fetch all Slack credentials
creds_result = conn.execute(
text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'")
)
all_slack_creds = creds_result.fetchall()
if not all_slack_creds:
return
for cred_row in all_slack_creds:
credential_id, credential_json = cred_row
credential_json = (
credential_json.tobytes().decode("utf-8")
if isinstance(credential_json, memoryview)
else credential_json.decode("utf-8")
)
credential_data = json.loads(credential_json)
slack_bot_token = credential_data.get("slack_bot_token")
if not slack_bot_token:
print(
f"No slack_bot_token found for credential {credential_id}. "
"Your Slack connector will not function until you upgrade and provide a valid token."
)
continue
client = WebClient(token=slack_bot_token)
try:
auth_response = client.auth_test()
workspace = auth_response["url"].split("//")[1].split(".")[0]
# Update only the connectors linked to this credential
# (and which are Slack connectors).
op.execute(
f"""
UPDATE connector AS c
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{{workspace}}',
to_jsonb('{workspace}'::text)
)
FROM connector_credential_pair AS ccp
WHERE ccp.connector_id = c.id
AND c.source = 'SLACK'
AND ccp.credential_id = {credential_id}
"""
)
except Exception:
print(
f"We were unable to get the workspace url for your Slack Connector with id {credential_id}."
)
print("This connector will no longer work until you upgrade.")
continue

View File

@@ -47,11 +47,3 @@ OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", ""
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

View File

@@ -3,15 +3,12 @@ import logging
import uuid
import aiohttp # Async HTTP client
import httpx
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantCreationPayload
@@ -23,7 +20,6 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
@@ -39,27 +35,22 @@ from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
async def get_or_create_tenant_id(
email: str, referral_source: str | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
try:
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
@@ -131,17 +122,6 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
@@ -287,36 +267,3 @@ def configure_default_api_keys(db_session: Session) -> None:
logger.info(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)
async def submit_to_hubspot(
email: str, referral_source: str | None, request: Request
) -> None:
if not HUBSPOT_TRACKING_URL:
logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
return
# HubSpot tracking cookie
hubspot_cookie = request.cookies.get("hubspotutk")
# IP address
ip_address = request.client.host if request.client else None
data = {
"fields": [
{"name": "email", "value": email},
{"name": "referral_source", "value": referral_source or ""},
],
"context": {
"hutk": hubspot_cookie,
"ipAddress": ip_address,
"pageUri": str(request.url),
"pageName": "User Registration",
},
}
async with httpx.AsyncClient() as client:
response = await client.post(HUBSPOT_TRACKING_URL, json=data)
if response.status_code != 200:
logger.error(f"Failed to submit to HubSpot: {response.text}")

View File

@@ -1,18 +0,0 @@
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
logger = setup_logger()
posthog = Posthog(project_api_key=POSTHOG_API_KEY, host=POSTHOG_HOST)
def event_telemetry(
distinct_id: str,
event: str,
properties: dict | None = None,
) -> None:
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
posthog.capture(distinct_id, event, properties)

View File

@@ -27,8 +27,8 @@ from shared_configs.configs import SENTRY_DSN
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
HF_CACHE_PATH = Path("/root/.cache/huggingface/")
TEMP_HF_CACHE_PATH = Path("/root/.cache/temp_huggingface/")
transformer_logging.set_verbosity_error()

View File

@@ -4,8 +4,6 @@ from typing import cast
from onyx.auth.schemas import UserRole
from onyx.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from onyx.configs.constants import NO_AUTH_USER_EMAIL
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.key_value_store.store import KeyValueStore
from onyx.key_value_store.store import KvKeyNotFoundError
from onyx.server.manage.models import UserInfo
@@ -32,8 +30,8 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
return UserInfo(
id=NO_AUTH_USER_ID,
email=NO_AUTH_USER_EMAIL,
id="__no_auth_user__",
email="anonymous@onyx.app",
is_active=True,
is_superuser=False,
is_verified=True,

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
@@ -73,8 +72,6 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
@@ -91,7 +88,6 @@ from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -229,26 +225,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
# We verify the password here to make sure it's valid before we proceed
await self.validate_password(
user_create.password, cast(schemas.UC, user_create)
)
user_count: int | None = None
referral_source = (
request.cookies.get("referral_source", None)
if request is not None
else None
)
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
request=request,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -291,37 +278,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
# Validate password according to basic security guidelines
if len(password) < 12:
raise exceptions.InvalidPasswordException(
reason="Password must be at least 12 characters long."
)
if len(password) > 64:
raise exceptions.InvalidPasswordException(
reason="Password must not exceed 64 characters."
)
if not any(char.isupper() for char in password):
raise exceptions.InvalidPasswordException(
reason="Password must contain at least one uppercase letter."
)
if not any(char.islower() for char in password):
raise exceptions.InvalidPasswordException(
reason="Password must contain at least one lowercase letter."
)
if not any(char.isdigit() for char in password):
raise exceptions.InvalidPasswordException(
reason="Password must contain at least one number."
)
if not any(char in PASSWORD_SPECIAL_CHARS for char in password):
raise exceptions.InvalidPasswordException(
reason="Password must contain at least one special character from the following set: "
f"{PASSWORD_SPECIAL_CHARS}."
)
return
return user
async def oauth_callback(
self,
@@ -336,18 +293,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
referral_source = (
getattr(request.state, "referral_source", None) if request else None
)
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
request=request,
)
if not tenant_id:
@@ -409,7 +365,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
@@ -463,39 +418,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user.email,
request=request,
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
user_count = await get_user_count()
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if user_count == 1:
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=MilestoneRecordType.USER_SIGNED_UP,
properties=None,
db_session=db_session,
)
else:
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=MilestoneRecordType.MULTIPLE_USERS,
properties=None,
db_session=db_session,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.notice(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
@@ -527,7 +449,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
@@ -588,7 +510,7 @@ class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,

View File

@@ -8,8 +8,6 @@ import sentry_sdk
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
from celery.signals import task_postrun
from celery.signals import task_prerun
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
@@ -36,11 +34,8 @@ from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -61,8 +56,8 @@ def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
pass
@@ -351,36 +346,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
def on_setup_logging(
loglevel: int,
logfile: str | None,
format: str,
colorize: bool,
**kwargs: Any,
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats the root logger
root_logger = logging.getLogger()
root_logger.handlers = []
# Define the log format
log_format = (
"%(levelname)-8s %(asctime)s %(filename)15s:%(lineno)-4d: %(name)s %(message)s"
)
# Set up the root handler
root_handler = logging.StreamHandler()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
root_formatter = ColoredFormatter(
log_format,
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_handler.setFormatter(root_formatter)
root_logger.addHandler(root_handler)
root_logger.addHandler(root_handler) # Apply the handler to the root logger
if logfile:
root_file_handler = logging.FileHandler(logfile)
root_file_formatter = PlainFormatter(
log_format,
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
root_file_handler.setFormatter(root_file_formatter)
@@ -388,23 +373,19 @@ def on_setup_logging(
root_logger.setLevel(loglevel)
# Configure the task logger
task_logger.handlers = []
task_handler = logging.StreamHandler()
task_handler.addFilter(TenantContextFilter())
# reformats celery's task logger
task_formatter = CeleryTaskColoredFormatter(
log_format,
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_handler = logging.StreamHandler() # Set up a handler for the task logger
task_handler.setFormatter(task_formatter)
task_logger.addHandler(task_handler)
task_logger.addHandler(task_handler) # Apply the handler to the task logger
if logfile:
task_file_handler = logging.FileHandler(logfile)
task_file_handler.addFilter(TenantContextFilter())
task_file_formatter = CeleryTaskPlainFormatter(
log_format,
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
task_file_handler.setFormatter(task_file_formatter)
@@ -413,55 +394,10 @@ def on_setup_logging(
task_logger.setLevel(loglevel)
task_logger.propagate = False
# Hide celery task received and succeeded/failed messages
# hide celery task received spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
strategy.logger.setLevel(logging.WARNING)
# hide celery task succeeded/failed spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
trace.logger.setLevel(logging.WARNING)
class TenantContextFilter(logging.Filter):
"""Logging filter to inject tenant ID into the logger's name."""
def filter(self, record: logging.LogRecord) -> bool:
if not MULTI_TENANT:
record.name = ""
return True
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id:
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
record.name = f"[t:{tenant_id}]"
else:
record.name = ""
return True
@task_prerun.connect
def set_tenant_id(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
"""Signal handler to set tenant ID in context var before task starts."""
tenant_id = (
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if kwargs
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@task_postrun.connect
def reset_tenant_id(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
"""Signal handler to reset tenant ID in context var after task ends."""
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)

View File

@@ -44,18 +44,18 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating task update")
logger.info("Reload interval reached, initiating tenant task update")
self._update_tenant_tasks()
self._last_reload = now
logger.info("Task update completed, reset reload timer")
logger.info("Tenant task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting task update process")
logger.info("Starting tenant task update process")
try:
logger.info("Fetching all IDs")
logger.info("Fetching all tenant IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
logger.info(f"Found {len(tenant_ids)} tenants")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
@@ -70,7 +70,7 @@ class DynamicTenantScheduler(PersistentScheduler):
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if (
@@ -83,7 +83,7 @@ class DynamicTenantScheduler(PersistentScheduler):
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new item: {tenant_id}")
logger.info(f"Processing new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
@@ -129,10 +129,11 @@ class DynamicTenantScheduler(PersistentScheduler):
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict

View File

@@ -1,6 +1,4 @@
# These are helper objects for tracking the keys we need to write in redis
import json
from typing import Any
from typing import cast
from redis import Redis
@@ -25,25 +23,3 @@ def celery_get_queue_length(queue: str, r: Redis) -> int:
total_length += cast(int, length)
return total_length
def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
"""This is a redis specific way to find a task for a particular queue in redis.
It is priority aware and knows how to look through the multiple redis lists
used to implement task prioritization.
This operation is not atomic.
This is a linear search O(n) ... so be careful using it when the task queues can be larger.
Returns true if the id is in the queue, False if not.
"""
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
if task_dict.get("headers", {}).get("id") == task_id:
return True
return False

View File

@@ -4,80 +4,55 @@ from typing import Any
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
"options": {"priority": OnyxCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
"options": {"priority": OnyxCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
"options": {"priority": OnyxCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
"options": {"priority": OnyxCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {
"priority": OnyxCeleryPriority.LOWEST,
"expires": 60,
},
"options": {"priority": OnyxCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
"options": {"priority": OnyxCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
"options": {"priority": OnyxCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
"options": {"priority": OnyxCeleryPriority.HIGH},
},
]

View File

@@ -76,7 +76,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during connector deletion check")
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
@@ -131,14 +131,14 @@ def try_generate_document_cc_pair_cleanup_tasks(
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
raise TaskDependencyError(
"Connector deletion - Delayed (indexing in progress): "
f"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings.id}"
)
if redis_connector.prune.fenced:
raise TaskDependencyError(
"Connector deletion - Delayed (pruning in progress): "
f"Connector deletion - Delayed (pruning in progress): "
f"cc_pair={cc_pair_id}"
)
@@ -175,7 +175,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
# return 0
task_logger.info(
"RedisConnectorDeletion.generate_tasks finished. "
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)

View File

@@ -1,9 +1,7 @@
import time
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import Any
import redis
import sentry_sdk
@@ -17,7 +15,6 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
@@ -29,7 +26,6 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -166,19 +162,11 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
tasks_created = 0
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
r = get_redis_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = redis_client.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -283,7 +271,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
search_settings_instance,
reindex,
db_session,
redis_client,
r,
tenant_id,
)
if attempt_id:
@@ -298,9 +286,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
db_session, redis_client
)
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
@@ -318,28 +304,12 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
# we want to run this less frequently than the overall task
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
# clear any indexing fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
task_logger.info("Validating indexing fences...")
validate_indexing_fences(
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
)
except Exception:
task_logger.exception("Exception while validating indexing fences")
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during indexing check")
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if locked:
if lock_beat.owned():
@@ -350,190 +320,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"tenant={tenant_id}"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
return tasks_created
def validate_indexing_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_indexing_tasks: set[str] = set()
active_indexing_tasks: set[str] = set()
indexing_worker_names: list[str] = []
# filter for and create an indexing specific inspect object
inspect = celery_app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if not workers:
raise ValueError("No workers found!")
for worker_name in list(workers.keys()):
if "indexing" in worker_name:
indexing_worker_names.append(worker_name)
if len(indexing_worker_names) == 0:
raise ValueError("No indexing workers found!")
inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
# NOTE: each dict entry is a map of worker name to a list of tasks
# we want sets for reserved task and active task id's to optimize
# subsequent validation lookups
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore
if reserved_tasks is None:
raise ValueError("inspect_indexing.reserved() returned None!")
for _, task_list in reserved_tasks.items():
for task in task_list:
reserved_indexing_tasks.add(task["id"])
# get the list of active tasks
active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore
if active_tasks is None:
raise ValueError("inspect_indexing.active() returned None!")
for _, task_list in active_tasks.items():
for task in task_list:
active_indexing_tasks.add(task["id"])
# validate all existing indexing jobs
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
active_indexing_tasks,
r_celery,
db_session,
)
return
def validate_indexing_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
active_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. Active signal is renewed with a 5 minute TTL
1.1 When the fence is created
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved or active list for a worker
2. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
if payload.celery_task_id in active_tasks:
# the celery task is active (aka currently executing)
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# we didn't find any direct indication that associated celery tasks exist, but they still might be there
# due to gaps in our ability to check states during transitions
# Rely on the active signal (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed."
)
redis_connector_index.reset()
return
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
@@ -680,7 +469,6 @@ def try_creating_indexing_task(
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
@@ -714,14 +502,13 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
@@ -753,6 +540,7 @@ def connector_indexing_proxy_task(
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -775,14 +563,15 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -797,6 +586,7 @@ def connector_indexing_proxy_task(
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -856,7 +646,7 @@ def connector_indexing_proxy_task(
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
# even though logging clearly indicates that they completed successfully
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
@@ -891,6 +681,7 @@ def connector_indexing_proxy_task(
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
@@ -1086,7 +877,6 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
@@ -1116,6 +906,7 @@ def connector_indexing_task(
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)

View File

@@ -122,7 +122,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during pruning check")
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
@@ -308,7 +308,7 @@ def connector_pruning_generator_task(
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
task_logger.info(
"Pruning set collected: "
f"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
@@ -324,7 +324,7 @@ def connector_pruning_generator_task(
return None
task_logger.info(
"RedisConnector.prune.generate_tasks finished. "
f"RedisConnector.prune.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)

View File

@@ -60,7 +60,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.debug(f"Task start: doc={document_id}")
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -129,13 +129,16 @@ def document_by_cc_pair_cleanup_task(
db_session.commit()
task_logger.info(
f"tenant={tenant_id} "
f"doc={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
)
return False
except Exception as ex:
if isinstance(ex, RetryError):
@@ -154,12 +157,15 @@ def document_by_cc_pair_cleanup_task(
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"tenant={tenant_id} "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(f"Unexpected exception: doc={document_id}")
task_logger.exception(
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
)
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
@@ -170,7 +176,7 @@ def document_by_cc_pair_cleanup_task(
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
f"tenant={tenant_id} doc={document_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
# delete the cc pair relationship now and let reconciliation clean it up

View File

@@ -1,4 +1,3 @@
import time
import traceback
from datetime import datetime
from datetime import timezone
@@ -90,11 +89,10 @@ logger = setup_logger()
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -158,15 +156,11 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during vespa metadata sync")
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
return
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
@@ -736,7 +730,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False if it exited early to prevent overlap
"""
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -831,8 +824,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if lock_beat.owned():
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
return True
@@ -882,9 +873,13 @@ def vespa_metadata_sync_task(
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
task_logger.info(
f"tenant={tenant_id} doc={document_id} action=sync chunks={chunks_affected}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
)
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(
@@ -902,13 +897,14 @@ def vespa_metadata_sync_task(
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"tenant={tenant_id} "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(
f"Unexpected exception during vespa metadata sync: doc={document_id}"
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64

View File

@@ -11,7 +11,6 @@ from onyx.background.indexing.tracer import OnyxTracer
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import IndexAttemptMetadata
@@ -35,7 +34,6 @@ from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import global_version
logger = setup_logger()
@@ -398,15 +396,6 @@ def _run_indexing(
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
properties=None,
db_session=db_session,
)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"

View File

@@ -31,8 +31,6 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
@@ -55,9 +53,6 @@ from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
@@ -122,7 +117,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -362,31 +356,6 @@ def stream_chat_message_objects(
if not persona:
raise RuntimeError("No persona specified or found for chat session")
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
user=user,
event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
db_session=db_session,
)
update_user_assistant_milestone(
milestone=multi_assistant_milestone,
user_id=str(user.id) if user else NO_AUTH_USER_ID,
assistant_id=persona.id,
db_session=db_session,
)
_, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
milestone=multi_assistant_milestone,
db_session=db_session,
)
if just_hit_multi_assistant_milestone:
mt_cloud_telemetry(
distinct_id=tenant_id,
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
properties=None,
)
# If a prompt override is specified via the API, use that with highest priority
# but for saving it, we are just mapping it to an existing prompt
prompt_id = new_msg_req.prompt_id

View File

@@ -65,7 +65,7 @@ class CitationProcessor:
# Handle code blocks without language tags
if "`" in self.curr_segment:
if self.curr_segment.endswith("`"):
pass
return
elif "```" in self.curr_segment:
piece_that_comes_after = self.curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):

View File

@@ -1,7 +1,6 @@
import json
import os
import urllib.parse
from typing import cast
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
@@ -145,7 +144,6 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -176,9 +174,6 @@ try:
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
@@ -488,21 +483,6 @@ SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
# allow for custom error messages for different errors returned by litellm
# for example, can specify: {"Violated content safety policy": "EVIL REQUEST!!!"}
# to make it so that if an LLM call returns an error containing "Violated content safety policy"
# the end user will see "EVIL REQUEST!!!" instead of the default error message.
_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = os.environ.get(
"LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS", ""
)
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS: dict[str, str] | None = None
try:
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = cast(
dict[str, str], json.loads(_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS)
)
except json.JSONDecodeError:
pass
#####
# Enterprise Edition Configs
#####

View File

@@ -63,10 +63,6 @@ LANGUAGE_CHAT_NAMING_HINT = (
or "The name of the conversation must be in the same language as the user query."
)
# Number of prompts each persona should have
NUM_PERSONA_PROMPTS = 4
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.

View File

@@ -15,9 +15,6 @@ ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
NO_AUTH_USER_ID = "__no_auth_user__"
NO_AUTH_USER_EMAIL = "anonymous@onyx.app"
# For chunking/processing chunks
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
@@ -49,7 +46,6 @@ POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
SSL_CERT_FILE = "bundle.pem"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
@@ -174,10 +170,6 @@ class AuthType(str, Enum):
CLOUD = "cloud"
# Special characters for password validation
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
class SessionType(str, Enum):
CHAT = "Chat"
SEARCH = "Search"
@@ -218,19 +210,6 @@ class FileOrigin(str, Enum):
OTHER = "other"
class MilestoneRecordType(str, Enum):
TENANT_CREATED = "tenant_created"
USER_SIGNED_UP = "user_signed_up"
MULTIPLE_USERS = "multiple_users"
VISITED_ADMIN_PAGE = "visited_admin_page"
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"
RAN_QUERY = "ran_query"
MULTIPLE_ASSISTANTS = "multiple_assistants"
CREATED_ASSISTANT = "created_assistant"
CREATED_ONYX_BOT = "created_onyx_bot"
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
@@ -275,10 +254,6 @@ class OnyxRedisLocks:
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
class OnyxCeleryPriority(int, Enum):
HIGHEST = 0
HIGH = auto()

View File

@@ -141,20 +141,14 @@ def get_valid_messages_from_query_sessions(
return {row.chat_session_id: row.message for row in first_messages}
# Retrieves chat sessions by user
# Chat sessions do not include onyxbot flows
def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
include_onyxbot_flows: bool = False,
limit: int = 50,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
@@ -316,23 +310,6 @@ def update_chat_session(
return chat_session
def delete_all_chat_sessions_for_user(
user: User | None, db_session: Session, hard_delete: bool = HARD_DELETE_CHATS
) -> None:
user_id = user.id if user is not None else None
query = db_session.query(ChatSession).filter(
ChatSession.user_id == user_id, ChatSession.onyxbot_flow.is_(False)
)
if hard_delete:
query.delete(synchronize_session=False)
else:
query.update({ChatSession.deleted: True}, synchronize_session=False)
db_session.commit()
def delete_chat_session(
user_id: UUID | None,
chat_session_id: UUID,

View File

@@ -1,7 +1,5 @@
import contextlib
import os
import re
import ssl
import threading
import time
from collections.abc import AsyncGenerator
@@ -12,8 +10,6 @@ from datetime import datetime
from typing import Any
from typing import ContextManager
import asyncpg # type: ignore
import boto3
import jwt
from fastapi import HTTPException
from fastapi import Request
@@ -27,7 +23,6 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from onyx.configs.app_configs import AWS_REGION
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -42,7 +37,6 @@ from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -55,87 +49,28 @@ logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
# Global so we don't create more than one engine per process
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
"""Create an SSL context if IAM authentication is enabled, else return None."""
if USE_IAM_AUTH:
return ssl.create_default_context(cafile=SSL_CERT_FILE)
return None
ssl_context = create_ssl_context_if_iam()
def get_iam_auth_token(
host: str, port: str, user: str, region: str = "us-east-2"
) -> str:
"""
Generate an IAM authentication token using boto3.
"""
client = boto3.client("rds", region_name=region)
token = client.generate_db_auth_token(
DBHostname=host, Port=int(port), DBUsername=user
)
return token
def configure_psycopg2_iam_auth(
cparams: dict[str, Any], host: str, port: str, user: str, region: str
) -> None:
"""
Configure cparams for psycopg2 with IAM token and SSL.
"""
token = get_iam_auth_token(host, port, user, region)
cparams["password"] = token
cparams["sslmode"] = "require"
cparams["sslrootcert"] = SSL_CERT_FILE
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
user: str = POSTGRES_USER,
password: str = POSTGRES_PASSWORD,
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
use_iam: bool = USE_IAM_AUTH,
region: str = "us-west-2",
) -> str:
if use_iam:
base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
else:
base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
# For asyncpg, do not include application_name in the connection string
if app_name and db_api != "asyncpg":
if "?" in base_conn_str:
return f"{base_conn_str}&application_name={app_name}"
else:
return f"{base_conn_str}?application_name={app_name}"
return base_conn_str
if LOG_POSTGRES_LATENCY:
# Function to log before query execution
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
conn.info["query_start_time"] = time.time()
# Function to log after query execution
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
total_time = time.time() - conn.info["query_start_time"]
# don't spam TOO hard
if total_time > 0.1:
logger.debug(
f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
@@ -143,6 +78,7 @@ if LOG_POSTGRES_LATENCY:
if LOG_POSTGRES_CONN_COUNTS:
# Global counter for connection checkouts and checkins
checkout_count = 0
checkin_count = 0
@@ -169,13 +105,21 @@ if LOG_POSTGRES_CONN_COUNTS:
logger.debug(f"Total connection checkins: {checkin_count}")
"""END DEBUGGING LOGGING"""
def get_db_current_time(db_session: Session) -> datetime:
"""Get the current time from Postgres representing the start of the transaction
Within the same transaction this value will not update
This datetime object returned should be timezone aware, default Postgres timezone is UTC
"""
result = db_session.execute(text("SELECT NOW()")).scalar()
if result is None:
raise ValueError("Database did not return a time")
return result
# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
@@ -184,9 +128,16 @@ def is_valid_schema_name(name: str) -> bool:
class SqlEngine:
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
Will eventually subsume most of the standalone functions in this file.
Sync only for now.
"""
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
# Default parameters for engine creation
DEFAULT_ENGINE_KWARGS = {
"pool_size": 20,
"max_overflow": 5,
@@ -194,27 +145,33 @@ class SqlEngine:
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
def __init__(self) -> None:
pass
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
"""Private helper method to create and return an Engine."""
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
engine = create_engine(connection_string, **merged_kwargs)
if USE_IAM_AUTH:
event.listen(engine, "do_connect", provide_iam_token)
return engine
return create_engine(connection_string, **merged_kwargs)
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
"""Allow the caller to init the engine with extra params. Different clients
such as the API server and different Celery workers and tasks
need different settings.
"""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
@classmethod
def get_engine(cls) -> Engine:
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
already been called. You probably want to init first!
"""
if not cls._engine:
with cls._lock:
if not cls._engine:
@@ -223,10 +180,12 @@ class SqlEngine:
@classmethod
def set_app_name(cls, app_name: str) -> None:
"""Class method to set the app name."""
cls._app_name = app_name
@classmethod
def get_app_name(cls) -> str:
"""Class method to get current app name."""
if not cls._app_name:
return ""
return cls._app_name
@@ -258,71 +217,56 @@ def get_all_tenant_ids() -> list[str] | list[None]:
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
user: str = POSTGRES_USER,
password: str = POSTGRES_PASSWORD,
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def get_sqlalchemy_engine() -> Engine:
return SqlEngine.get_engine()
async def get_async_connection() -> Any:
"""
Custom connection function for async engine when using IAM auth.
"""
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
db = POSTGRES_DB
token = get_iam_auth_token(host, port, user, AWS_REGION)
# asyncpg requires 'ssl="require"' if SSL needed
return await asyncpg.connect(
user=user, password=token, host=host, port=int(port), database=db, ssl="require"
)
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
app_name = SqlEngine.get_app_name() + "_async"
connection_string = build_connection_string(
db_api=ASYNC_DB_API,
use_iam=USE_IAM_AUTH,
)
connect_args: dict[str, Any] = {}
if app_name:
connect_args["server_settings"] = {"application_name": app_name}
connect_args["ssl"] = ssl_context
# Underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args=connect_args,
connect_args={
"server_settings": {
"application_name": SqlEngine.get_app_name() + "_async"
}
},
# async engine is only used by API server, so we can use those values
# here as well
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
if USE_IAM_AUTH:
@event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
def provide_iam_token_async(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
# For async engine using asyncpg, we still need to set the IAM token here.
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION)
cparams["password"] = token
cparams["ssl"] = ssl_context
return _ASYNC_ENGINE
# Dependency to get the current tenant ID
# If no token is present, uses the default schema for this use case
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -331,6 +275,7 @@ def get_current_tenant_id(request: Request) -> str:
token = request.cookies.get("fastapiusersauth")
if not token:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no token is present, use the default schema or handle accordingly
return current_value
try:
@@ -344,6 +289,7 @@ def get_current_tenant_id(request: Request) -> str:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
return CURRENT_TENANT_ID_CONTEXTVAR.get()
@@ -370,6 +316,7 @@ async def get_async_session_with_tenant(
async with async_session_factory() as session:
try:
# Set the search_path to the tenant's schema
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
@@ -379,6 +326,8 @@ async def get_async_session_with_tenant(
)
except Exception:
logger.exception("Error setting search_path.")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
else:
yield session
@@ -386,6 +335,9 @@ async def get_async_session_with_tenant(
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@@ -397,6 +349,7 @@ def get_session_with_tenant(
) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
@@ -404,20 +357,27 @@ def get_session_with_tenant(
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
@@ -430,17 +390,21 @@ def get_session_with_tenant(
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
@@ -460,9 +424,12 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(detail="User must authenticate")
raise BasicAuthenticationError(
detail="User must authenticate",
)
engine = get_sqlalchemy_engine()
@@ -470,17 +437,20 @@ def get_session() -> Generator[Session, None, None]:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
yield async_session
@@ -491,6 +461,7 @@ def get_session_context_manager() -> ContextManager[Session]:
def get_session_factory() -> sessionmaker[Session]:
"""Get a session factory."""
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
@@ -518,13 +489,3 @@ async def warm_up_connections(
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
if USE_IAM_AUTH:
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
region = os.getenv("AWS_REGION", "us-east-2")
# Configure for psycopg2 with IAM token
configure_psycopg2_iam_auth(cparams, host, port, user, region)

View File

@@ -1,99 +0,0 @@
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
from onyx.configs.constants import MilestoneRecordType
from onyx.db.models import Milestone
from onyx.db.models import User
USER_ASSISTANT_PREFIX = "user_assistants_used_"
MULTI_ASSISTANT_USED = "multi_assistant_used"
def create_milestone(
user: User | None,
event_type: MilestoneRecordType,
db_session: Session,
) -> Milestone:
milestone = Milestone(
event_type=event_type,
user_id=user.id if user else None,
)
db_session.add(milestone)
db_session.commit()
return milestone
def create_milestone_if_not_exists(
user: User | None, event_type: MilestoneRecordType, db_session: Session
) -> tuple[Milestone, bool]:
# Check if it exists
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one_or_none()
if milestone is not None:
return milestone, False
# If it doesn't exist, try to create it.
try:
milestone = create_milestone(user, event_type, db_session)
return milestone, True
except IntegrityError:
# Another thread or process inserted it in the meantime
db_session.rollback()
# Fetch again to return the existing record
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one() # Now should exist
return milestone, False
def update_user_assistant_milestone(
milestone: Milestone,
user_id: str | None,
assistant_id: int,
db_session: Session,
) -> None:
event_tracker = milestone.event_tracker
if event_tracker is None:
milestone.event_tracker = event_tracker = {}
if event_tracker.get(MULTI_ASSISTANT_USED):
# No need to keep tracking and populating if the milestone has already been hit
return
user_key = f"{USER_ASSISTANT_PREFIX}{user_id}"
if event_tracker.get(user_key) is None:
event_tracker[user_key] = [assistant_id]
elif assistant_id not in event_tracker[user_key]:
event_tracker[user_key].append(assistant_id)
flag_modified(milestone, "event_tracker")
db_session.commit()
def check_multi_assistant_milestone(
milestone: Milestone,
db_session: Session,
) -> tuple[bool, bool]:
"""Returns if the milestone was hit and if it was just hit for the first time"""
event_tracker = milestone.event_tracker
if event_tracker is None:
return False, False
if event_tracker.get(MULTI_ASSISTANT_USED):
return True, False
for key, value in event_tracker.items():
if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
event_tracker[MULTI_ASSISTANT_USED] = True
flag_modified(milestone, "event_tracker")
db_session.commit()
return True, True
return False, False

View File

@@ -5,8 +5,6 @@ from typing import Literal
from typing import NotRequired
from typing import Optional
from uuid import uuid4
from pydantic import BaseModel
from typing_extensions import TypedDict # noreorder
from uuid import UUID
@@ -39,7 +37,7 @@ from sqlalchemy.types import TypeDecorator
from onyx.auth.schemas import UserRole
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
@@ -1010,7 +1008,7 @@ class ChatSession(Base):
"ChatFolder", back_populates="chat_sessions"
)
messages: Mapped[list["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="all, delete-orphan"
"ChatMessage", back_populates="chat_session"
)
persona: Mapped["Persona"] = relationship("Persona")
@@ -1078,8 +1076,6 @@ class ChatMessage(Base):
"SearchDoc",
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
cascade="all, delete-orphan",
single_parent=True,
)
tool_call: Mapped["ToolCall"] = relationship(
@@ -1348,11 +1344,6 @@ class StarterMessage(TypedDict):
message: str
class StarterMessageModel(BaseModel):
name: str
message: str
class Persona(Base):
__tablename__ = "persona"
@@ -1543,32 +1534,6 @@ class SlackBot(Base):
)
class Milestone(Base):
# This table is used to track significant events for a deployment towards finding value
# The table is currently not used for features but it may be used in the future to inform
# users about the product features and encourage usage/exploration.
__tablename__ = "milestone"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
event_type: Mapped[MilestoneRecordType] = mapped_column(String)
# Need to track counts and specific ids of certain events to know if the Milestone has been reached
event_tracker: Mapped[dict | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User | None] = relationship("User")
__table_args__ = (UniqueConstraint("event_type", name="uq_milestone_event_type"),)
class TaskQueueState(Base):
# Currently refers to Celery Tasks
__tablename__ = "task_queue_jobs"

View File

@@ -543,10 +543,6 @@ def upsert_persona(
if tools is not None:
existing_persona.tools = tools or []
# We should only update display priority if it is not already set
if existing_persona.display_priority is None:
existing_persona.display_priority = display_priority
persona = existing_persona
else:

View File

@@ -369,19 +369,6 @@ class AdminCapable(abc.ABC):
raise NotImplementedError
class RandomCapable(abc.ABC):
"""Class must implement random document retrieval capability"""
@abc.abstractmethod
def random_retrieval(
self,
filters: IndexFilters,
num_to_retrieve: int = 10,
) -> list[InferenceChunkUncleaned]:
"""Retrieve random chunks matching the filters"""
raise NotImplementedError
class BaseIndex(
Verifiable,
Indexable,
@@ -389,7 +376,6 @@ class BaseIndex(
Deletable,
AdminCapable,
IdRetrievalCapable,
RandomCapable,
abc.ABC,
):
"""

View File

@@ -218,10 +218,4 @@ schema DANSWER_CHUNK_NAME {
expression: bm25(content) + (5 * bm25(title))
}
}
rank-profile random_ {
first-phase {
expression: random.match
}
}
}

View File

@@ -23,7 +23,7 @@
<resource-limits>
<!-- Default is 75% but this can be increased for Dockerized deployments -->
<!-- https://docs.vespa.ai/en/operations/feed-block.html -->
<disk>0.85</disk>
<disk>0.75</disk>
</resource-limits>
</tuning>
<engine>

View File

@@ -2,7 +2,6 @@ import concurrent.futures
import io
import logging
import os
import random
import re
import time
import urllib
@@ -904,32 +903,6 @@ class VespaIndex(DocumentIndex):
logger.info("Batch deletion completed")
def random_retrieval(
self,
filters: IndexFilters,
num_to_retrieve: int = 10,
) -> list[InferenceChunkUncleaned]:
"""Retrieve random chunks matching the filters using Vespa's random ranking
This method is currently used for random chunk retrieval in the context of
assistant starter message creation (passed as sample context for usage by the assistant).
"""
vespa_where_clauses = build_vespa_filters(filters, remove_trailing_and=True)
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
random_seed = random.randint(0, 1000000)
params: dict[str, str | int | float] = {
"yql": yql,
"hits": num_to_retrieve,
"timeout": VESPA_TIMEOUT,
"ranking.profile": "random_",
"ranking.properties.random.seed": random_seed,
}
return query_vespa(params)
class _VespaDeleteRequest:
def __init__(self, document_id: str, index_name: str) -> None:

View File

@@ -19,12 +19,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def build_vespa_filters(
filters: IndexFilters,
*,
include_hidden: bool = False,
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
) -> str:
def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:
def _build_or_filters(key: str, vals: list[str] | None) -> str:
if vals is None:
return ""
@@ -83,9 +78,6 @@ def build_vespa_filters(
filter_str += _build_time_filter(filters.time_cutoff)
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5] # We remove the trailing " and "
return filter_str

View File

@@ -28,7 +28,6 @@ from litellm.exceptions import RateLimitError # type: ignore
from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -46,19 +45,10 @@ logger = setup_logger()
def litellm_exception_to_error_msg(
e: Exception,
llm: LLM,
fallback_to_error_msg: bool = False,
custom_error_msg_mappings: dict[str, str]
| None = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
) -> str:
error_msg = str(e)
if custom_error_msg_mappings:
for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
if error_msg_pattern in error_msg:
return custom_error_msg
if isinstance(e, BadRequestError):
error_msg = "Bad request: The server couldn't process your request. Please check your input."
elif isinstance(e, AuthenticationError):

View File

@@ -243,7 +243,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_query_router)
include_router_with_global_prefix_prepended(application, admin_router)
include_router_with_global_prefix_prepended(application, connector_router)
include_router_with_global_prefix_prepended(application, user_router)
include_router_with_global_prefix_prepended(application, credential_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, folder_router)

View File

@@ -1,46 +0,0 @@
PERSONA_CATEGORY_GENERATION_PROMPT = """
Based on the assistant's name, description, and instructions, generate a list of {num_categories}
**unique and diverse** categories that represent different types of starter messages a user
might send to initiate a conversation with this chatbot assistant.
**Ensure that the categories are varied and cover a wide range of topics related to the assistant's capabilities.**
Provide the categories as a JSON array of strings **without any code fences or additional text**.
**Context about the assistant:**
- **Name**: {name}
- **Description**: {description}
- **Instructions**: {instructions}
""".strip()
PERSONA_STARTER_MESSAGE_CREATION_PROMPT = """
Create a starter message that a **user** might send to initiate a conversation with a chatbot assistant.
**Category**: {category}
Your response should include two parts:
1. **Title**: A short, engaging title that reflects the user's intent
(e.g., 'Need Travel Advice', 'Question About Coding', 'Looking for Book Recommendations').
2. **Message**: The actual message that the user would send to the assistant.
This should be natural, engaging, and encourage a helpful response from the assistant.
**Avoid overly specific details; keep the message general and broadly applicable.**
For example:
- Instead of "I've just adopted a 6-month-old Labrador puppy who's pulling on the leash,"
write "I'm having trouble training my new puppy to walk nicely on a leash."
Ensure each part is clearly labeled and separated as shown above.
Do not provide any additional text or explanation and be extremely concise
**Context about the assistant:**
- **Name**: {name}
- **Description**: {description}
- **Instructions**: {instructions}
""".strip()
if __name__ == "__main__":
print(PERSONA_CATEGORY_GENERATION_PROMPT)
print(PERSONA_STARTER_MESSAGE_CREATION_PROMPT)

View File

@@ -31,10 +31,6 @@ class RedisConnectorIndex:
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
# used to signal the overall workflow is still active
# it's difficult to prevent
ACTIVE_PREFIX = PREFIX + "_active"
def __init__(
self,
tenant_id: str | None,
@@ -58,7 +54,6 @@ class RedisConnectorIndex:
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -112,26 +107,6 @@ class RedisConnectorIndex:
# 10 minute TTL is good.
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=300)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True
return False
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
@@ -163,7 +138,6 @@ class RedisConnectorIndex:
return status
def reset(self) -> None:
self.redis.delete(self.active_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)

View File

@@ -1,271 +0,0 @@
import json
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from litellm import get_supported_openai_params
from sqlalchemy.orm import Session
from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
from onyx.configs.chat_configs import NUM_PERSONA_PROMPTS
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
from onyx.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from onyx.db.document_set import get_document_sets_by_ids
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llms
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT
from onyx.prompts.starter_messages import PERSONA_STARTER_MESSAGE_CREATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
logger = setup_logger()
def get_random_chunks_from_doc_sets(
doc_sets: List[str], db_session: Session, user: User | None = None
) -> List[InferenceChunk]:
"""
Retrieves random chunks from the specified document sets.
"""
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(curr_ind_name, sec_ind_name)
acl_filters = build_access_filters_for_user(user, db_session)
filters = IndexFilters(document_set=doc_sets, access_control_list=acl_filters)
chunks = document_index.random_retrieval(
filters=filters, num_to_retrieve=NUM_PERSONA_PROMPT_GENERATION_CHUNKS
)
return cleanup_chunks(chunks)
def parse_categories(content: str) -> List[str]:
"""
Parses the JSON array of categories from the LLM response.
"""
# Clean the response to remove code fences and extra whitespace
content = content.strip().strip("```").strip()
if content.startswith("json"):
content = content[4:].strip()
try:
categories = json.loads(content)
if not isinstance(categories, list):
logger.error("Categories are not a list.")
return []
return categories
except json.JSONDecodeError as e:
logger.error(f"Failed to parse categories: {e}")
return []
def generate_start_message_prompts(
name: str,
description: str,
instructions: str,
categories: List[str],
chunk_contents: str,
supports_structured_output: bool,
fast_llm: Any,
) -> List[FunctionCall]:
"""
Generates the list of FunctionCall objects for starter message generation.
"""
functions = []
for category in categories:
# Create a prompt specific to the category
start_message_generation_prompt = (
PERSONA_STARTER_MESSAGE_CREATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
category=category,
)
)
if chunk_contents:
start_message_generation_prompt += (
"\n\nExample content this assistant has access to:\n"
"'''\n"
f"{chunk_contents}"
"\n'''"
)
if supports_structured_output:
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt, None, None, StarterMessage),
)
)
else:
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt,),
)
)
return functions
def parse_unstructured_output(output: str) -> Dict[str, str]:
"""
Parses the assistant's unstructured output into a dictionary with keys:
- 'name' (Title)
- 'message' (Message)
"""
# Debug output
logger.debug(f"LLM Output for starter message creation: {output}")
# Patterns to match
title_pattern = r"(?i)^\**Title\**\s*:\s*(.+)"
message_pattern = r"(?i)^\**Message\**\s*:\s*(.+)"
# Initialize the response dictionary
response_dict = {}
# Split the output into lines
lines = output.strip().split("\n")
# Variables to keep track of the current key being processed
current_key = None
current_value_lines = []
for line in lines:
# Check for title
title_match = re.match(title_pattern, line.strip())
if title_match:
# Save previous key-value pair if any
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()
current_value_lines = []
current_key = "name"
current_value_lines.append(title_match.group(1).strip())
continue
# Check for message
message_match = re.match(message_pattern, line.strip())
if message_match:
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()
current_value_lines = []
current_key = "message"
current_value_lines.append(message_match.group(1).strip())
continue
# If the line doesn't match a new key, append it to the current value
if current_key:
current_value_lines.append(line.strip())
# Add the last key-value pair
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()
# Validate that the necessary keys are present
if not all(k in response_dict for k in ["name", "message"]):
raise ValueError("Failed to parse the assistant's response.")
return response_dict
def generate_starter_messages(
name: str,
description: str,
instructions: str,
document_set_ids: List[int],
db_session: Session,
user: User | None,
) -> List[StarterMessage]:
"""
Generates starter messages by first obtaining categories and then generating messages for each category.
On failure, returns an empty list (or list with processed starter messages if some messages are processed successfully).
"""
_, fast_llm = get_default_llms(temperature=0.5)
provider = fast_llm.config.model_provider
model = fast_llm.config.model_name
params = get_supported_openai_params(model=model, custom_llm_provider=provider)
supports_structured_output = (
isinstance(params, list) and "response_format" in params
)
# Generate categories
category_generation_prompt = PERSONA_CATEGORY_GENERATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
num_categories=NUM_PERSONA_PROMPTS,
)
category_response = fast_llm.invoke(category_generation_prompt)
categories = parse_categories(cast(str, category_response.content))
if not categories:
logger.error("No categories were generated.")
return []
# Fetch example content if document sets are provided
if document_set_ids:
document_sets = get_document_sets_by_ids(
document_set_ids=document_set_ids,
db_session=db_session,
)
chunks = get_random_chunks_from_doc_sets(
doc_sets=[doc_set.name for doc_set in document_sets],
db_session=db_session,
user=user,
)
# Add example content context
chunk_contents = "\n".join(chunk.content.strip() for chunk in chunks)
else:
chunk_contents = ""
# Generate prompts for starter messages
functions = generate_start_message_prompts(
name,
description,
instructions,
categories,
chunk_contents,
supports_structured_output,
fast_llm,
)
# Run LLM calls in parallel
if not functions:
logger.error("No functions to execute for starter message generation.")
return []
results = run_functions_in_parallel(function_calls=functions)
prompts = []
for response in results.values():
try:
if supports_structured_output:
response_dict = json.loads(response.content)
else:
response_dict = parse_unstructured_output(response.content)
starter_message = StarterMessage(
name=response_dict["name"],
message=response_dict["message"],
)
prompts.append(starter_message)
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse starter message: {e}")
continue
return prompts

View File

@@ -9,7 +9,6 @@ from onyx.access.models import default_public_access
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_DOCUMENTS_SEEDED_KEY
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
@@ -72,7 +71,7 @@ def _create_indexable_chunks(
source_links={0: preprocessed_doc["url"]},
section_continuation=False,
source_document=document,
title_prefix=preprocessed_doc["title"] + RETURN_SEPARATOR,
title_prefix=preprocessed_doc["title"],
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
@@ -217,7 +216,7 @@ def seed_initial_documents(
# as we just sent over the Vespa schema and there is a slight delay
index_with_retries = retry_builder()(document_index.index)
index_with_retries(chunks=chunks, fresh_index=cohere_enabled)
index_with_retries(chunks=chunks, fresh_index=True)
# Mock a run for the UI even though it did not actually call out to anything
mock_successful_index_attempt(

View File

@@ -48,7 +48,6 @@ def load_personas_from_yaml(
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
@@ -128,7 +127,6 @@ def load_personas_from_yaml(
display_priority=(
existing_persona.display_priority
if existing_persona is not None
and persona.get("display_priority") is None
else persona.get("display_priority")
),
is_visible=(

View File

@@ -7,7 +7,7 @@ personas:
- id: 0
name: "Search"
description: >
Assistant with access to documents and knowledge from Connected Sources.
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml
prompts:
- "Answer-Question"
@@ -39,7 +39,7 @@ personas:
document_sets: []
icon_shape: 23013
icon_color: "#6FB1FF"
display_priority: 0
display_priority: 1
is_visible: true
starter_messages:
- name: "Give me an overview of what's here"
@@ -54,7 +54,7 @@ personas:
- id: 1
name: "General"
description: >
Assistant with no search functionalities. Chat directly with the Large Language Model.
Assistant with no access to documents. Chat with just the Large Language Model.
prompts:
- "OnlyLLM"
num_chunks: 0
@@ -64,7 +64,7 @@ personas:
document_sets: []
icon_shape: 50910
icon_color: "#FF6F6F"
display_priority: 1
display_priority: 0
is_visible: true
starter_messages:
- name: "Summarize a document"

View File

@@ -21,7 +21,6 @@ from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.google_utils.google_auth import (
@@ -111,7 +110,6 @@ from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -641,15 +639,6 @@ def get_connector_indexing_status(
)
)
# Visiting admin page brings the user to the current connectors page which calls this endpoint
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.VISITED_ADMIN_PAGE,
properties=None,
db_session=db_session,
)
return indexing_statuses
@@ -674,7 +663,6 @@ def create_connector_from_model(
connector_data: ConnectorUpdateRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> ObjectCreationIdResponse:
try:
_validate_connector_allowed(connector_data.source)
@@ -689,20 +677,10 @@ def create_connector_from_model(
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
)
connector_base = connector_data.to_connector_base()
connector_response = create_connector(
return create_connector(
db_session=db_session,
connector_data=connector_base,
)
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_CONNECTOR,
properties=None,
db_session=db_session,
)
return connector_response
except ValueError as e:
logger.error(f"Error creating connector: {e}")
raise HTTPException(status_code=400, detail=str(e))
@@ -713,7 +691,6 @@ def create_connector_with_mock_credential(
connector_data: ConnectorUpdateRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
@@ -751,15 +728,6 @@ def create_connector_with_mock_credential(
cc_pair_name=connector_data.name,
groups=connector_data.groups,
)
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_CONNECTOR,
properties=None,
db_session=db_session,
)
return response
except ValueError as e:

View File

@@ -15,11 +15,8 @@ from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.chat.prompt_builder.utils import build_dummy_prompt
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.db.notification import create_notification
from onyx.db.persona import create_assistant_category
@@ -37,11 +34,7 @@ from onyx.db.persona import update_persona_shared_users
from onyx.db.persona import update_persona_visibility
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaCategoryCreate
from onyx.server.features.persona.models import PersonaCategoryResponse
@@ -51,7 +44,6 @@ from onyx.server.features.persona.models import PromptTemplateResponse
from onyx.server.models import DisplayPriorityRequest
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
logger = setup_logger()
@@ -175,25 +167,14 @@ def create_persona(
create_persona_request: CreatePersonaRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> PersonaSnapshot:
persona_snapshot = create_update_persona(
return create_update_persona(
persona_id=None,
create_persona_request=create_persona_request,
user=user,
db_session=db_session,
)
create_milestone_and_report(
user=user,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_ASSISTANT,
properties=None,
db_session=db_session,
)
return persona_snapshot
# NOTE: This endpoint cannot update persona configuration options that
# are core to the persona, such as its display priority and
@@ -382,26 +363,3 @@ def build_final_template_prompt(
retrieval_disabled=retrieval_disabled,
)
)
@basic_router.post("/assistant-prompt-refresh")
def build_assistant_prompts(
generate_persona_prompt_request: GenerateStarterMessageRequest,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
) -> list[StarterMessage]:
try:
logger.info(
"Generating starter messages for user: %s", user.id if user else "Anonymous"
)
return generate_starter_messages(
name=generate_persona_prompt_request.name,
description=generate_persona_prompt_request.description,
instructions=generate_persona_prompt_request.instructions,
document_set_ids=generate_persona_prompt_request.document_set_ids,
db_session=db_session,
user=user,
)
except Exception as e:
logger.exception("Failed to generate starter messages")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -17,14 +17,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# More minimal request for generating a persona prompt
class GenerateStarterMessageRequest(BaseModel):
name: str
description: str
instructions: str
document_set_ids: list[int]
class CreatePersonaRequest(BaseModel):
name: str
description: str

View File

@@ -4,9 +4,7 @@ from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.constants import MilestoneRecordType
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import ChannelConfig
from onyx.db.models import User
@@ -27,7 +25,6 @@ from onyx.server.manage.models import SlackBot
from onyx.server.manage.models import SlackBotCreationRequest
from onyx.server.manage.models import SlackChannelConfig
from onyx.server.manage.models import SlackChannelConfigCreationRequest
from onyx.utils.telemetry import create_milestone_and_report
router = APIRouter(prefix="/manage")
@@ -220,7 +217,6 @@ def create_bot(
slack_bot_creation_request: SlackBotCreationRequest,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> SlackBot:
slack_bot_model = insert_slack_bot(
db_session=db_session,
@@ -229,15 +225,6 @@ def create_bot(
bot_token=slack_bot_creation_request.bot_token,
app_token=slack_bot_creation_request.app_token,
)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_ONYX_BOT,
properties=None,
db_session=db_session,
)
return SlackBot.from_model(slack_bot_model)

View File

@@ -30,12 +30,10 @@ from onyx.chat.prompt_builder.citations_prompt import (
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from onyx.db.chat import add_chats_to_session_from_slack_thread
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import delete_all_chat_sessions_for_user
from onyx.db.chat import delete_chat_session
from onyx.db.chat import duplicate_chat_session_for_user_from_slack
from onyx.db.chat import get_chat_message
@@ -46,9 +44,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.db.models import User
@@ -85,7 +81,6 @@ from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
logger = setup_logger()
@@ -281,17 +276,6 @@ def patch_chat_session(
return None
@router.delete("/delete-all-chat-sessions")
def delete_all_chat_sessions(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
delete_all_chat_sessions_for_user(user=user, db_session=db_session)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/delete-chat-session/{session_id}")
def delete_chat_session_by_id(
session_id: UUID,
@@ -331,9 +315,8 @@ def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_limited_user),
_rate_limit_check: None = Depends(check_token_rate_limits),
_: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
tenant_id: str = Depends(get_current_tenant_id),
) -> StreamingResponse:
"""
This endpoint is both used for all the following purposes:
@@ -364,15 +347,6 @@ def handle_new_chat_message(
):
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
with get_session_with_tenant(tenant_id) as db_session:
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.RAN_QUERY,
properties=None,
db_session=db_session,
)
def stream_generator() -> Generator[str, None, None]:
try:
for packet in stream_chat_message(

View File

@@ -11,7 +11,6 @@ from onyx.chat.models import RetrievalDocs
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import SessionType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import RerankingDetails
@@ -152,10 +151,6 @@ class ChatSessionUpdateRequest(BaseModel):
sharing_status: ChatSessionSharedStatus
class DeleteAllSessionsRequest(BaseModel):
session_type: SessionType
class RenameChatSessionResponse(BaseModel):
new_name: str # This is only really useful if the name is generated

View File

@@ -10,17 +10,10 @@ from onyx.configs.app_configs import DISABLE_TELEMETRY
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT
_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
_CACHED_UUID: str | None = None
@@ -110,37 +103,3 @@ def optional_telemetry(
except Exception:
# Should never interfere with normal functions of Onyx
pass
def mt_cloud_telemetry(
distinct_id: str,
event: MilestoneRecordType,
properties: dict | None = None,
) -> None:
if not MULTI_TENANT:
return
# MIT version should not need to include any Posthog code
# This is only for Onyx MT Cloud, this code should also never be hit, no reason for any orgs to
# be running the Multi Tenant version of Onyx.
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="event_telemetry",
fallback=noop_fallback,
)(distinct_id, event, properties)
def create_milestone_and_report(
user: User | None,
distinct_id: str,
event_type: MilestoneRecordType,
properties: dict | None,
db_session: Session,
) -> None:
_, is_new = create_milestone_if_not_exists(user, event_type, db_session)
if is_new:
mt_cloud_telemetry(
distinct_id=distinct_id,
event=event_type,
properties=properties,
)

View File

@@ -9,7 +9,6 @@ mypy-extensions==1.0.0
mypy==1.8.0
pandas-stubs==2.2.3.241009
pandas==2.2.3
posthog==3.7.4
pre-commit==3.2.2
pytest-asyncio==0.22.0
pytest==7.4.4

View File

@@ -1,3 +1,2 @@
cohere==5.6.1
posthog==3.7.4
python3-saml==1.15.0
cohere==5.6.1

View File

@@ -48,7 +48,4 @@ sleep 1
echo "Running Alembic migration..."
alembic upgrade head
# Run the following instead of the above if using MT cloud
# alembic -n schema_private upgrade head
echo "Containers restarted and migration completed."

View File

@@ -14,7 +14,7 @@ from tests.integration.common_utils.test_models import DATestUser
DOMAIN = "test.com"
DEFAULT_PASSWORD = "TestPassword123!"
DEFAULT_PASSWORD = "test"
def build_email(name: str) -> str:

View File

@@ -219,7 +219,6 @@ def test_slack_permission_sync(
assert private_message not in onyx_doc_message_strings
@pytest.mark.xfail(reason="flaky", strict=False)
def test_slack_group_permission_sync(
reset: None,
vespa_client: vespa_fixture,

View File

@@ -376,26 +376,6 @@ def process_text(
"The code demonstrates variable assignment.",
[],
),
(
"Long JSON string in code block",
[
"```json\n{",
'"name": "John Doe",',
'"age": 30,',
'"city": "New York",',
'"hobbies": ["reading", "swimming", "cycling"],',
'"education": {',
' "degree": "Bachelor\'s",',
' "major": "Computer Science",',
' "university": "Example University"',
"}",
"}\n```",
],
'```json\n{"name": "John Doe","age": 30,"city": "New York","hobbies": '
'["reading", "swimming", "cycling"],"education": { '
'"degree": "Bachelor\'s", "major": "Computer Science", "university": "Example University"}}\n```',
[],
),
(
"Citation as a single token",
[

View File

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-beat
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-heavy
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-indexing
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-light
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-primary
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -92,7 +92,6 @@ services:
- LOG_POSTGRES_LATENCY=${LOG_POSTGRES_LATENCY:-}
- LOG_POSTGRES_CONN_COUNTS=${LOG_POSTGRES_CONN_COUNTS:-}
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
@@ -104,13 +103,6 @@ services:
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
# Seeding configuration
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -231,13 +223,6 @@ services:
# Enterprise Edition stuff
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:

View File

@@ -84,7 +84,6 @@ services:
# (time spent on finding the right docs + time spent fetching summaries from disk)
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
# Chat Configs
- HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-}
@@ -92,13 +91,6 @@ services:
# Enterprise Edition only
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -200,13 +192,6 @@ services:
# Enterprise Edition only
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:

View File

@@ -22,13 +22,6 @@ services:
- VESPA_HOST=index
- REDIS_HOST=cache
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -59,13 +52,6 @@ services:
- REDIS_HOST=cache
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:

View File

@@ -23,13 +23,6 @@ services:
- VESPA_HOST=index
- REDIS_HOST=cache
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -64,13 +57,6 @@ services:
- REDIS_HOST=cache
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -237,7 +223,7 @@ services:
volumes:
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
logging::wq
logging:
driver: json-file
options:
max-size: "50m"
@@ -259,6 +245,3 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:

View File

@@ -60,12 +60,3 @@ spec:
envFrom:
- configMapRef:
name: env-configmap
# Uncomment if you are using IAM auth for Postgres
# volumeMounts:
# - name: bundle-pem
# mountPath: "/app/certs"
# readOnly: true
# volumes:
# - name: bundle-pem
# secret:
# secretName: bundle-pem-secret

View File

@@ -43,7 +43,6 @@ spec:
# - name: my-ca-cert-volume
# mountPath: /etc/ssl/certs/custom-ca.crt
# subPath: my-ca.crt
# Optional volume for CA certificate
# volumes:
# - name: my-cas-cert-volume
@@ -52,13 +51,3 @@ spec:
# items:
# - key: my-ca.crt
# path: my-ca.crt
# Uncomment if you are using IAM auth for Postgres
# volumeMounts:
# - name: bundle-pem
# mountPath: "/app/certs"
# readOnly: true
# volumes:
# - name: bundle-pem
# secret:
# secretName: bundle-pem-secret

View File

@@ -75,8 +75,7 @@ export default function Page() {
},
{} as Record<SourceCategory, SourceMetadata[]>
);
}, [sources, filterSources, searchTerm]);
}, [sources, searchTerm]);
const handleKeyPress = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === "Enter") {
const filteredCategories = Object.entries(categorizedSources).filter(

View File

@@ -1,7 +1,7 @@
"use client";
import { Option } from "@/components/Dropdown";
import { generateRandomIconShape, createSVG } from "@/lib/assistantIconUtils";
import { CCPairBasicInfo, DocumentSet, User } from "@/lib/types";
import { Separator } from "@/components/ui/separator";
import { Button } from "@/components/ui/button";
@@ -9,11 +9,12 @@ import { Textarea } from "@/components/ui/textarea";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import {
ArrayHelpers,
ErrorMessage,
Field,
FieldArray,
Form,
Formik,
FormikProps,
useFormikContext,
} from "formik";
import {
@@ -26,6 +27,7 @@ import {
import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces";
@@ -39,9 +41,10 @@ import {
} from "@/components/ui/tooltip";
import Link from "next/link";
import { useRouter } from "next/navigation";
import { useEffect, useMemo, useState } from "react";
import { FiInfo, FiRefreshCcw } from "react-icons/fi";
import { useEffect, useState } from "react";
import { FiInfo, FiX } from "react-icons/fi";
import * as Yup from "yup";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, PersonaCategory, StarterMessage } from "./interfaces";
@@ -63,9 +66,6 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { buildImgUrl } from "@/app/chat/files/images/utils";
import { LlmList } from "@/components/llm/LLMList";
import { useAssistants } from "@/components/context/AssistantsContext";
import { debounce } from "lodash";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import StarterMessagesList from "./StarterMessageList";
import { Input } from "@/components/ui/input";
import { CategoryCard } from "./CategoryCard";
@@ -129,14 +129,12 @@ export function AssistantEditor({
];
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const [hasEditedStarterMessage, setHasEditedStarterMessage] = useState(false);
const [showPersonaCategory, setShowPersonaCategory] = useState(!admin);
// state to persist across formik reformatting
const [defautIconColor, _setDeafultIconColor] = useState(
colorOptions[Math.floor(Math.random() * colorOptions.length)]
);
const [isRefreshing, setIsRefreshing] = useState(false);
const [defaultIconShape, setDefaultIconShape] = useState<any>(null);
@@ -150,10 +148,6 @@ export function AssistantEditor({
const [removePersonaImage, setRemovePersonaImage] = useState(false);
const autoStarterMessageEnabled = useMemo(
() => llmProviders.length > 0,
[llmProviders.length]
);
const isUpdate = existingPersona !== undefined && existingPersona !== null;
const existingPrompt = existingPersona?.prompts[0] ?? null;
const defaultProvider = llmProviders.find(
@@ -223,24 +217,7 @@ export function AssistantEditor({
existingPersona?.llm_model_provider_override ?? null,
llm_model_version_override:
existingPersona?.llm_model_version_override ?? null,
starter_messages: existingPersona?.starter_messages ?? [
{
name: "",
message: "",
},
{
name: "",
message: "",
},
{
name: "",
message: "",
},
{
name: "",
message: "",
},
],
starter_messages: existingPersona?.starter_messages ?? [],
enabled_tools_map: enabledToolsMap,
icon_color: existingPersona?.icon_color ?? defautIconColor,
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
@@ -251,44 +228,6 @@ export function AssistantEditor({
groups: existingPersona?.groups ?? [],
};
interface AssistantPrompt {
message: string;
name: string;
}
const debouncedRefreshPrompts = debounce(
async (values: any, setFieldValue: any) => {
if (!autoStarterMessageEnabled) {
return;
}
setIsRefreshing(true);
try {
const response = await fetch("/api/persona/assistant-prompt-refresh", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
name: values.name,
description: values.description,
document_set_ids: values.document_set_ids,
instructions: values.system_prompt || values.task_prompt,
}),
});
const data: AssistantPrompt = await response.json();
if (response.ok) {
setFieldValue("starter_messages", data);
}
} catch (error) {
console.error("Failed to refresh prompts:", error);
} finally {
setIsRefreshing(false);
}
},
1000
);
const [isRequestSuccessful, setIsRequestSuccessful] = useState(false);
return (
@@ -482,8 +421,6 @@ export function AssistantEditor({
isSubmitting,
values,
setFieldValue,
errors,
...formikProps
}: FormikProps<any>) => {
function toggleToolInValues(toolId: number) {
@@ -508,7 +445,6 @@ export function AssistantEditor({
return (
<Form className="w-full text-text-950">
{/* Refresh starter messages when name or description changes */}
<div className="w-full flex gap-x-2 justify-center">
<Popover
open={isIconDropdownOpen}
@@ -1048,91 +984,6 @@ export function AssistantEditor({
</div>
</div>
<div className="mb-6 w-full flex flex-col">
<div className="flex gap-x-2 items-center">
<div className="block font-medium text-base">
Starter Messages
</div>
</div>
<SubLabel>
Pre-configured messages that help users understand what this
assistant can do and how to interact with it effectively.
</SubLabel>
<div className="relative w-fit">
<TooltipProvider delayDuration={50}>
<Tooltip>
<TooltipTrigger asChild>
<div>
<Button
type="button"
size="sm"
onClick={() =>
debouncedRefreshPrompts(values, setFieldValue)
}
disabled={
!autoStarterMessageEnabled ||
isRefreshing ||
(Object.keys(errors).length > 0 &&
Object.keys(errors).some(
(key) => !key.startsWith("starter_messages")
))
}
className={`
px-3 py-2
mr-auto
my-2
flex gap-x-2
text-sm font-medium
rounded-lg shadow-sm
items-center gap-2
transition-colors duration-200
${
isRefreshing || !autoStarterMessageEnabled
? "bg-gray-100 text-gray-400 cursor-not-allowed"
: "bg-blue-50 text-blue-600 hover:bg-blue-100 active:bg-blue-200"
}
`}
>
<div className="flex items-center gap-x-2">
{isRefreshing ? (
<FiRefreshCcw className="w-4 h-4 animate-spin text-gray-400" />
) : (
<SwapIcon className="w-4 h-4 text-blue-600" />
)}
Generate
</div>
</Button>
</div>
</TooltipTrigger>
{!autoStarterMessageEnabled && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
No LLM providers configured. Generation is not
available.
</p>
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
</div>
<div className="w-full">
<FieldArray
name="starter_messages"
render={(arrayHelpers: ArrayHelpers) => (
<StarterMessagesList
isRefreshing={isRefreshing}
values={values.starter_messages}
arrayHelpers={arrayHelpers}
touchStarterMessages={() => {
setHasEditedStarterMessage(true);
}}
/>
)}
/>
</div>
</div>
{admin && (
<AdvancedOptionsToggle
title="Categories"
@@ -1339,12 +1190,136 @@ export function AssistantEditor({
</>
)}
<div className="mb-6 flex flex-col">
<div className="flex gap-x-2 items-center">
<div className="block font-medium text-base">
Starter Messages (Optional){" "}
</div>
</div>
<SubLabel>
Add pre-defined messages to help users get started. Only
the first 4 will be displayed.
</SubLabel>
<FieldArray
name="starter_messages"
render={(
arrayHelpers: ArrayHelpers<StarterMessage[]>
) => (
<div>
{values.starter_messages &&
values.starter_messages.length > 0 &&
values.starter_messages.map(
(
starterMessage: StarterMessage,
index: number
) => {
return (
<div
key={index}
className={index === 0 ? "mt-2" : "mt-6"}
>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label small>Name</Label>
<SubLabel>
Shows up as the &quot;title&quot;
for this Starter Message. For
example, &quot;Write an email&quot;.
</SubLabel>
<Field
name={`starter_messages[${index}].name`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages[${index}].name`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label small>Message</Label>
<SubLabel>
The actual message to be sent as the
initial user message if a user
selects this starter prompt. For
example, &quot;Write me an email to
a client about a new billing feature
we just released.&quot;
</SubLabel>
<Field
name={`starter_messages[${index}].message`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
min-h-12
mr-4
line-clamp-
`}
as="textarea"
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages[${index}].message`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() =>
arrayHelpers.remove(index)
}
/>
</div>
</div>
</div>
);
}
)}
<Button
onClick={() => {
arrayHelpers.push({
name: "",
description: "",
message: "",
});
}}
className="mt-3"
size="sm"
variant="next"
>
Add New
</Button>
</div>
)}
/>
</div>
<IsPublicGroupSelector
formikProps={{
values,
isSubmitting,
setFieldValue,
errors,
...formikProps,
}}
objectName="assistant"

View File

@@ -1,198 +0,0 @@
"use client";
import { ArrayHelpers, ErrorMessage, Field, useFormikContext } from "formik";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@radix-ui/react-tooltip";
import { useEffect } from "react";
import { FiInfo, FiTrash2, FiPlus } from "react-icons/fi";
import { StarterMessage } from "./interfaces";
import { Label } from "@/components/admin/connectors/Field";
export default function StarterMessagesList({
values,
arrayHelpers,
isRefreshing,
touchStarterMessages,
}: {
values: StarterMessage[];
arrayHelpers: ArrayHelpers;
isRefreshing: boolean;
touchStarterMessages: () => void;
}) {
const { handleChange } = useFormikContext();
// Group starter messages into rows of 2 for display purposes
const rows = values.reduce((acc: StarterMessage[][], curr, i) => {
if (i % 2 === 0) acc.push([curr]);
else acc[acc.length - 1].push(curr);
return acc;
}, []);
const canAddMore = values.length <= 6;
return (
<div className="mt-4 flex flex-col gap-6">
{rows.map((row, rowIndex) => (
<div key={rowIndex} className="flex items-start gap-4">
<div className="grid grid-cols-2 gap-6 w-full xl:w-fit">
{row.map((starterMessage, colIndex) => (
<div
key={rowIndex * 2 + colIndex}
className="bg-white max-w-full w-full xl:w-[500px] border border-border rounded-lg shadow-md transition-shadow duration-200 p-6"
>
<div className="space-y-5">
{isRefreshing ? (
<div className="w-full">
<div className="w-full">
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
<div className="h-10 w-full bg-gray-200 rounded animate-pulse" />
</div>
<div>
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
<div className="h-10 w-full bg-gray-200 rounded animate-pulse" />
</div>
<div>
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
<div className="h-24 w-full bg-gray-200 rounded animate-pulse" />
</div>
</div>
) : (
<>
<div>
<div className="flex w-full items-center gap-x-1">
<Label
small
className="text-sm font-medium text-gray-700"
>
Name
</Label>
<TooltipProvider delayDuration={50}>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
Shows up as the &quot;title&quot; for this
Starter Message. For example, &quot;Write an
email.&quot;
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<Field
name={`starter_messages.${
rowIndex * 2 + colIndex
}.name`}
className="mt-1 w-full px-4 py-2.5 bg-background border border-border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition"
autoComplete="off"
placeholder="Enter a name..."
onChange={(e: any) => {
touchStarterMessages();
handleChange(e);
}}
/>
<ErrorMessage
name={`starter_messages.${
rowIndex * 2 + colIndex
}.name`}
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
<div>
<div className="flex w-full items-center gap-x-1">
<Label
small
className="text-sm font-medium text-gray-700"
>
Message
</Label>
<TooltipProvider delayDuration={50}>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
The actual message to be sent as the initial
user message.
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<Field
name={`starter_messages.${
rowIndex * 2 + colIndex
}.message`}
className="mt-1 text-sm w-full px-4 py-2.5 bg-background border border-border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition min-h-[100px] resize-y"
as="textarea"
autoComplete="off"
placeholder="Enter the message..."
onChange={(e: any) => {
touchStarterMessages();
handleChange(e);
}}
/>
<ErrorMessage
name={`starter_messages.${
rowIndex * 2 + colIndex
}.message`}
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
</>
)}
</div>
</div>
))}
</div>
<button
type="button"
onClick={() => {
arrayHelpers.remove(rowIndex * 2 + 1);
arrayHelpers.remove(rowIndex * 2);
}}
className="p-1.5 bg-white border border-gray-200 rounded-full text-gray-400 hover:text-red-500 hover:border-red-200 transition-colors mt-2"
aria-label="Delete row"
>
<FiTrash2 size={14} />
</button>
</div>
))}
{canAddMore && (
<button
type="button"
onClick={() => {
arrayHelpers.push({
name: "",
message: "",
});
arrayHelpers.push({
name: "",
message: "",
});
}}
className="self-start flex items-center gap-2 px-4 py-2 bg-white border border-gray-200 rounded-lg text-gray-600 hover:bg-gray-50 hover:border-gray-300 transition-colors"
>
<FiPlus size={16} />
<span>Add Row</span>
</button>
)}
</div>
);
}

View File

@@ -2,7 +2,7 @@ import { useFormContext } from "@/components/context/FormContext";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { SettingsIcon } from "@/components/icons/icons";
import { Logo } from "@/components/logo/Logo";
import { Logo } from "@/components/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { credentialTemplates } from "@/lib/connectors/credentials";
import Link from "next/link";
@@ -62,11 +62,11 @@ export default function Sidebar() {
];
return (
<div className="flex flex-none w-[250px] text-default">
<div className="flex flex-none w-[250px] bg-background text-default">
<div
className={`
fixed
bg-background-sidebar
bg-background-100
h-screen
transition-all
bg-opacity-80

View File

@@ -326,9 +326,8 @@ export function CCPairIndexingStatusTable({
(sum, status) => sum + status.docs_indexed,
0
),
errors: statuses.filter(
(status) => status.last_finished_status === "failed"
).length,
errors: statuses.filter((status) => status.last_status === "failed")
.length,
};
});

View File

@@ -20,22 +20,26 @@ import { useRouter } from "next/navigation";
import { pageType } from "../chat/sessionSidebar/types";
import FixedLogo from "../chat/shared_chat_search/FixedLogo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { useChatContext } from "@/components/context/ChatContext";
interface SidebarWrapperProps<T extends object> {
chatSessions?: ChatSession[];
folders?: Folder[];
initiallyToggled: boolean;
openedFolders?: { [key: number]: boolean };
page: pageType;
size?: "sm" | "lg";
children: ReactNode;
}
export default function SidebarWrapper<T extends object>({
chatSessions,
initiallyToggled,
folders,
openedFolders,
page,
size = "sm",
children,
}: SidebarWrapperProps<T>) {
const { chatSessions, folders, openedFolders } = useChatContext();
const [toggledSidebar, setToggledSidebar] = useState(initiallyToggled);
const [showDocSidebar, setShowDocSidebar] = useState(false); // State to track if sidebar is open
// Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
@@ -128,7 +132,7 @@ export default function SidebarWrapper<T extends object>({
</div>
</div>
<div className="absolute h-svh px-2 left-0 w-full top-0">
<div className="absolute h-svh left-0 w-full top-0">
<FunctionalHeader
sidebarToggled={toggledSidebar}
toggleSidebar={toggleSidebar}

View File

@@ -1,15 +1,31 @@
"use client";
import SidebarWrapper from "../SidebarWrapper";
import { ChatSession } from "@/app/chat/interfaces";
import { Folder } from "@/app/chat/folders/interfaces";
import { Persona } from "@/app/admin/assistants/interfaces";
import { User } from "@/lib/types";
import { AssistantsGallery } from "./AssistantsGallery";
export default function WrappedAssistantsGallery({
toggleSidebar,
chatSessions,
initiallyToggled,
folders,
openedFolders,
}: {
toggleSidebar: boolean;
chatSessions: ChatSession[];
folders: Folder[];
initiallyToggled: boolean;
openedFolders?: { [key: number]: boolean };
}) {
return (
<SidebarWrapper page="chat" initiallyToggled={toggleSidebar}>
<SidebarWrapper
page="chat"
initiallyToggled={initiallyToggled}
chatSessions={chatSessions}
folders={folders}
openedFolders={openedFolders}
>
<AssistantsGallery />
</SidebarWrapper>
);

View File

@@ -5,7 +5,6 @@ import { unstable_noStore as noStore } from "next/cache";
import { redirect } from "next/navigation";
import WrappedAssistantsGallery from "./WrappedAssistantsGallery";
import { cookies } from "next/headers";
import { ChatProvider } from "@/components/context/ChatContext";
export default async function GalleryPage(props: {
searchParams: Promise<{ [key: string]: string }>;
@@ -27,38 +26,22 @@ export default async function GalleryPage(props: {
openedFolders,
toggleSidebar,
shouldShowWelcomeModal,
availableSources,
ccPairs,
documentSets,
tags,
llmProviders,
defaultAssistantId,
} = data;
return (
<ChatProvider
value={{
chatSessions,
availableSources,
ccPairs,
documentSets,
tags,
availableDocumentSets: documentSets,
availableTags: tags,
llmProviders,
folders,
openedFolders,
shouldShowWelcomeModal,
defaultAssistantId,
}}
>
<>
{shouldShowWelcomeModal && (
<WelcomeModal user={user} requestCookies={requestCookies} />
)}
<InstantSSRAutoRefresh />
<WrappedAssistantsGallery toggleSidebar={toggleSidebar} />
</ChatProvider>
<WrappedAssistantsGallery
initiallyToggled={toggleSidebar}
chatSessions={chatSessions}
folders={folders}
openedFolders={openedFolders}
/>
</>
);
}

View File

@@ -1,14 +1,28 @@
"use client";
import { AssistantsList } from "./AssistantsList";
import SidebarWrapper from "../SidebarWrapper";
import { ChatSession } from "@/app/chat/interfaces";
import { Folder } from "@/app/chat/folders/interfaces";
export default function WrappedAssistantsMine({
chatSessions,
initiallyToggled,
folders,
openedFolders,
}: {
chatSessions: ChatSession[];
folders: Folder[];
initiallyToggled: boolean;
openedFolders?: { [key: number]: boolean };
}) {
return (
<SidebarWrapper page="chat" initiallyToggled={initiallyToggled}>
<SidebarWrapper
page="chat"
initiallyToggled={initiallyToggled}
chatSessions={chatSessions}
folders={folders}
openedFolders={openedFolders}
>
<AssistantsList />
</SidebarWrapper>
);

View File

@@ -6,7 +6,6 @@ import { redirect } from "next/navigation";
import WrappedAssistantsMine from "./WrappedAssistantsMine";
import { WelcomeModal } from "@/components/initialSetup/welcome/WelcomeModalWrapper";
import { cookies } from "next/headers";
import { ChatProvider } from "@/components/context/ChatContext";
export default async function GalleryPage(props: {
searchParams: Promise<{ [key: string]: string }>;
@@ -28,37 +27,21 @@ export default async function GalleryPage(props: {
openedFolders,
toggleSidebar,
shouldShowWelcomeModal,
availableSources,
ccPairs,
documentSets,
tags,
llmProviders,
defaultAssistantId,
} = data;
return (
<ChatProvider
value={{
chatSessions,
availableSources,
ccPairs,
documentSets,
tags,
availableDocumentSets: documentSets,
availableTags: tags,
llmProviders,
folders,
openedFolders,
shouldShowWelcomeModal,
defaultAssistantId,
}}
>
<>
{shouldShowWelcomeModal && (
<WelcomeModal user={user} requestCookies={requestCookies} />
)}
<InstantSSRAutoRefresh />
<WrappedAssistantsMine initiallyToggled={toggleSidebar} />
</ChatProvider>
<WrappedAssistantsMine
initiallyToggled={toggleSidebar}
chatSessions={chatSessions}
folders={folders}
openedFolders={openedFolders}
/>
</>
);
}

View File

@@ -9,7 +9,6 @@ import * as Yup from "yup";
import { requestEmailVerification } from "../lib";
import { useState } from "react";
import { Spinner } from "@/components/Spinner";
import { set } from "lodash";
export function EmailPasswordForm({
isSignup = false,
@@ -48,12 +47,10 @@ export function EmailPasswordForm({
);
if (!response.ok) {
setIsWorking(false);
const errorDetail = (await response.json()).detail;
let errorMsg = "Unknown error";
if (typeof errorDetail === "object" && errorDetail.reason) {
errorMsg = errorDetail.reason;
} else if (errorDetail === "REGISTER_USER_ALREADY_EXISTS") {
if (errorDetail === "REGISTER_USER_ALREADY_EXISTS") {
errorMsg =
"An account already exists with the specified email.";
}

View File

@@ -6,7 +6,7 @@ import { useCallback, useEffect, useState } from "react";
import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
import { User } from "@/lib/types";
import { Logo } from "@/components/logo/Logo";
import { Logo } from "@/components/Logo";
export function Verify({ user }: { user: User | null }) {
const searchParams = useSearchParams();

View File

@@ -8,7 +8,7 @@ import { HealthCheckBanner } from "@/components/health/healthcheck";
import { User } from "@/lib/types";
import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "./RequestNewVerificationEmail";
import { Logo } from "@/components/logo/Logo";
import { Logo } from "@/components/Logo";
export default async function Page() {
// catch cases where the backend is completely unreachable here

View File

@@ -27,7 +27,6 @@ import {
buildLatestMessageChain,
checkAnyAssistantHasSearch,
createChatSession,
deleteAllChatSessions,
deleteChatSession,
getCitedDocumentsFromMessage,
getHumanAndAIMessageFromMessageNumber,
@@ -273,7 +272,6 @@ export function ChatPage({
};
const llmOverrideManager = useLlmOverride(
llmProviders,
modelVersionFromSearchParams || (user?.preferences.default_model ?? null),
selectedChatSession,
defaultTemperature
@@ -320,9 +318,9 @@ export function ChatPage({
);
if (personaDefault) {
llmOverrideManager.updateLLMOverride(personaDefault);
llmOverrideManager.setLlmOverride(personaDefault);
} else if (user?.preferences.default_model) {
llmOverrideManager.updateLLMOverride(
llmOverrideManager.setLlmOverride(
destructureValue(user?.preferences.default_model)
);
}
@@ -1204,6 +1202,7 @@ export function ChatPage({
assistant_message_id: number;
frozenMessageMap: Map<number, Message>;
} = null;
try {
const mapKeys = Array.from(
currentMessageMap(completeMessageDetail).keys()
@@ -1838,7 +1837,6 @@ export function ChatPage({
const innerSidebarElementRef = useRef<HTMLDivElement>(null);
const [settingsToggled, setSettingsToggled] = useState(false);
const [showDeleteAllModal, setShowDeleteAllModal] = useState(false);
const currentPersona = alternativeAssistant || liveAssistant;
useEffect(() => {
@@ -1905,6 +1903,11 @@ export function ChatPage({
const showShareModal = (chatSession: ChatSession) => {
setSharedChatSession(chatSession);
};
const [documentSelection, setDocumentSelection] = useState(false);
// const toggleDocumentSelectionAspects = () => {
// setDocumentSelection((documentSelection) => !documentSelection);
// setShowDocSidebar(false);
// };
const toggleDocumentSidebar = () => {
if (!documentSidebarToggled) {
@@ -1969,32 +1972,6 @@ export function ChatPage({
<ChatPopup />
{showDeleteAllModal && (
<DeleteEntityModal
entityType="All Chats"
entityName="all your chat sessions"
onClose={() => setShowDeleteAllModal(false)}
additionalDetails="This action cannot be undone. All your chat sessions will be deleted."
onSubmit={async () => {
const response = await deleteAllChatSessions("Chat");
if (response.ok) {
setShowDeleteAllModal(false);
setPopup({
message: "All your chat sessions have been deleted.",
type: "success",
});
refreshChatSessions();
router.push("/chat");
} else {
setPopup({
message: "Failed to delete all chat sessions.",
type: "error",
});
}
}}
/>
)}
{currentFeedback && (
<FeedbackModal
feedbackType={currentFeedback[0]}
@@ -2146,7 +2123,7 @@ export function ChatPage({
page="chat"
ref={innerSidebarElementRef}
toggleSidebar={toggleSidebar}
toggled={toggledSidebar}
toggled={toggledSidebar && !settings?.isMobile}
backgroundToggled={toggledSidebar || showHistorySidebar}
existingChats={chatSessions}
currentChatSession={selectedChatSession}
@@ -2155,7 +2132,6 @@ export function ChatPage({
removeToggle={removeToggle}
showShareModal={showShareModal}
showDeleteModal={showDeleteModal}
showDeleteAllModal={() => setShowDeleteAllModal(true)}
/>
</div>
</div>
@@ -2168,6 +2144,7 @@ export function ChatPage({
fixed
right-0
z-[1000]
bg-background
h-screen
transition-all
@@ -2217,6 +2194,8 @@ export function ChatPage({
{liveAssistant && (
<FunctionalHeader
toggleUserSettings={() => setUserSettingsToggled(true)}
liveAssistant={liveAssistant}
onAssistantChange={onAssistantChange}
sidebarToggled={toggledSidebar}
reset={() => setMessage("")}
page="chat"
@@ -2228,6 +2207,7 @@ export function ChatPage({
toggleSidebar={toggleSidebar}
currentChatSession={selectedChatSession}
documentSidebarToggled={documentSidebarToggled}
llmOverrideManager={llmOverrideManager}
/>
)}
@@ -2762,10 +2742,6 @@ export function ChatPage({
removeDocs={() => {
clearSelectedDocuments();
}}
showDocs={() => {
setFiltersToggled(false);
setDocumentSidebarToggled(true);
}}
removeFilters={() => {
filterManager.setSelectedSources([]);
filterManager.setSelectedTags([]);
@@ -2778,6 +2754,7 @@ export function ChatPage({
chatState={currentSessionChatState}
stopGenerating={stopGenerating}
openModelSettings={() => setSettingsToggled(true)}
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
// assistant stuff
selectedAssistant={liveAssistant}

View File

@@ -124,9 +124,9 @@ export default function RegenerateOption({
onHoverChange: (isHovered: boolean) => void;
onDropdownVisibleChange: (isVisible: boolean) => void;
}) {
const { llmProviders } = useChatContext();
const llmOverrideManager = useLlmOverride(llmProviders);
const llmOverrideManager = useLlmOverride();
const { llmProviders } = useChatContext();
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
const llmOptionsByProvider: {

View File

@@ -81,8 +81,6 @@ export function ChatDocumentDisplay({
}
};
const hasMetadata =
document.updated_at || Object.keys(document.metadata).length > 0;
return (
<div className={`opacity-100 ${modal ? "w-[90vw]" : "w-full"}`}>
<div
@@ -109,14 +107,8 @@ export function ChatDocumentDisplay({
: document.semantic_identifier || document.document_id}
</div>
</div>
{hasMetadata && (
<DocumentMetadataBlock modal={modal} document={document} />
)}
<div
className={`line-clamp-3 text-sm font-normal leading-snug text-gray-600 ${
hasMetadata ? "mt-2" : ""
}`}
>
<DocumentMetadataBlock modal={modal} document={document} />
<div className="line-clamp-3 pt-2 text-sm font-normal leading-snug text-gray-600">
{buildDocumentSummaryDisplay(
document.match_highlights,
document.blurb

View File

@@ -95,7 +95,7 @@ const FolderItem = ({
if (!continueEditing) {
setIsEditing(false);
}
router.refresh(); // Refresh values to update the sidebar
router.refresh();
} catch (error) {
setPopup({ message: "Failed to save folder name", type: "error" });
}

View File

@@ -31,7 +31,14 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
import { ChatState } from "../types";
import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText";
import { useAssistants } from "@/components/context/AssistantsContext";
import AnimatedToggle from "@/components/search/SearchBar";
import { Popup } from "@/components/admin/connectors/Popup";
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
import { IconType } from "react-icons";
import { LlmTab } from "../modal/configuration/LlmTab";
import { XIcon } from "lucide-react";
import { FilterPills } from "./FilterPills";
import { Tag } from "@/lib/types";
import FiltersDisplay from "./FilterDisplay";
const MAX_INPUT_HEIGHT = 200;
@@ -40,6 +47,7 @@ interface ChatInputBarProps {
removeFilters: () => void;
removeDocs: () => void;
openModelSettings: () => void;
showDocs: () => void;
showConfigureAPIKey: () => void;
selectedDocuments: OnyxDocument[];
message: string;
@@ -49,7 +57,6 @@ interface ChatInputBarProps {
filterManager: FilterManager;
llmOverrideManager: LlmOverrideManager;
chatState: ChatState;
showDocs: () => void;
alternativeAssistant: Persona | null;
// assistants
selectedAssistant: Persona;
@@ -68,8 +75,8 @@ export function ChatInputBar({
removeFilters,
removeDocs,
openModelSettings,
showConfigureAPIKey,
showDocs,
showConfigureAPIKey,
selectedDocuments,
message,
setMessage,
@@ -277,6 +284,10 @@ export function ChatInputBar({
</div>
)}
{/* <div>
<SelectedFilterDisplay filterManager={filterManager} />
</div> */}
<UnconfiguredProviderText showConfigureAPIKey={showConfigureAPIKey} />
<div
@@ -417,7 +428,9 @@ export function ChatInputBar({
style={{ scrollbarWidth: "thin" }}
role="textarea"
aria-multiline
placeholder="Ask me anything.."
placeholder={`Send a message ${
!settings?.isMobile ? "or try using @ or /" : ""
}`}
value={message}
onKeyDown={(event) => {
if (

View File

@@ -278,16 +278,6 @@ export async function deleteChatSession(chatSessionId: string) {
return response;
}
export async function deleteAllChatSessions(sessionType: "Chat" | "Search") {
const response = await fetch(`/api/chat/delete-all-chat-sessions`, {
method: "DELETE",
headers: {
"Content-Type": "application/json",
},
});
return response;
}
export async function* simulateLLMResponse(input: string, delay: number = 30) {
// Split the input string into tokens. This is a simple example, and in real use case, tokenization can be more complex.
// Iterate over tokens and yield them one by one

View File

@@ -812,7 +812,6 @@ export const HumanMessage = ({
outline-none
placeholder-gray-400
resize-none
text-text-editing-message
pl-4
overflow-y-auto
pr-12
@@ -871,6 +870,7 @@ export const HumanMessage = ({
py-2
px-3
w-fit
bg-hover
bg-background-strong
text-sm
rounded-lg
@@ -896,13 +896,15 @@ export const HumanMessage = ({
<TooltipProvider delayDuration={1000}>
<Tooltip>
<TooltipTrigger>
<HoverableIcon
icon={<FiEdit2 className="text-gray-600" />}
<button
className="hover:bg-hover p-1.5 rounded"
onClick={() => {
setIsEditing(true);
setIsHovered(false);
}}
/>
>
<FiEdit2 className="!h-4 !w-4" />
</button>
</TooltipTrigger>
<TooltipContent>Edit</TooltipContent>
</Tooltip>

View File

@@ -35,7 +35,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
checkPersonaRequiresImageGeneration(currentAssistant);
const { llmProviders } = useChatContext();
const { updateLLMOverride, temperature, updateTemperature } =
const { setLlmOverride, temperature, updateTemperature } =
llmOverrideManager;
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
@@ -60,7 +60,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
if (value == null) {
return;
}
updateLLMOverride(destructureValue(value));
setLlmOverride(destructureValue(value));
if (chatSessionId) {
updateModelOverrideForChatSession(chatSessionId, value as string);
}

View File

@@ -11,10 +11,13 @@ import { createFolder } from "../folders/FolderManagement";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { AssistantsIconSkeleton } from "@/components/icons/icons";
import {
AssistantsIconSkeleton,
ClosedBookIcon,
} from "@/components/icons/icons";
import { PagesTab } from "./PagesTab";
import { pageType } from "./types";
import LogoWithText from "@/components/header/LogoWithText";
import LogoType from "@/components/header/LogoType";
interface HistorySidebarProps {
page: pageType;
@@ -30,7 +33,6 @@ interface HistorySidebarProps {
showDeleteModal?: (chatSession: ChatSession) => void;
stopGenerating?: () => void;
explicitlyUntoggle: () => void;
showDeleteAllModal?: () => void;
backgroundToggled?: boolean;
}
@@ -50,7 +52,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
stopGenerating = () => null,
showShareModal,
showDeleteModal,
showDeleteAllModal,
backgroundToggled,
},
ref: ForwardedRef<HTMLDivElement>
@@ -99,19 +100,16 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
flex
flex-col relative
h-screen
pt-2
transition-transform
`}
>
<div className="pl-2">
<LogoWithText
showArrow={true}
toggled={toggled}
page={page}
toggleSidebar={toggleSidebar}
explicitlyUntoggle={explicitlyUntoggle}
/>
</div>
<LogoType
showArrow={true}
toggled={toggled}
page={page}
toggleSidebar={toggleSidebar}
explicitlyUntoggle={explicitlyUntoggle}
/>
{page == "chat" && (
<div className="mx-3 mt-4 gap-y-1 flex-col text-text-history-sidebar-button flex gap-x-1.5 items-center items-center">
<Link
@@ -178,7 +176,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
currentChatId={currentChatId}
folders={folders}
openedFolders={openedFolders}
showDeleteAllModal={showDeleteAllModal}
/>
</div>
</>

View File

@@ -9,8 +9,6 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { useRouter } from "next/navigation";
import { useState } from "react";
import { pageType } from "./types";
import { FiTrash2 } from "react-icons/fi";
import { NEXT_PUBLIC_DELETE_ALL_CHATS_ENABLED } from "@/lib/constants";
export function PagesTab({
page,
@@ -22,7 +20,6 @@ export function PagesTab({
newFolderId,
showShareModal,
showDeleteModal,
showDeleteAllModal,
}: {
page: pageType;
existingChats?: ChatSession[];
@@ -33,7 +30,6 @@ export function PagesTab({
newFolderId: number | null;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
showDeleteAllModal?: () => void;
}) {
const groupedChatSessions = existingChats
? groupSessionsByDateRange(existingChats)
@@ -67,98 +63,82 @@ export function PagesTab({
const isHistoryEmpty = !existingChats || existingChats.length === 0;
return (
<div className="flex flex-col relative h-full overflow-y-auto mb-1 ml-3 miniscroll mobile:pb-40">
<div
className={` flex-grow overflow-y-auto ${
NEXT_PUBLIC_DELETE_ALL_CHATS_ENABLED && "pb-20 "
}`}
>
{folders && folders.length > 0 && (
<div className="py-2 border-b border-border">
<div className="text-xs text-subtle flex pb-0.5 mb-1.5 mt-2 font-bold">
Chat Folders
</div>
<FolderList
newFolderId={newFolderId}
folders={folders}
currentChatId={currentChatId}
openedFolders={openedFolders}
showShareModal={showShareModal}
showDeleteModal={showDeleteModal}
/>
<div className="mb-1 text-text-sidebar ml-3 relative miniscroll mobile:pb-40 overflow-y-auto h-full">
{folders && folders.length > 0 && (
<div className="py-2 border-b border-border">
<div className="text-xs text-subtle flex pb-0.5 mb-1.5 mt-2 font-bold">
Chat Folders
</div>
)}
<div
onDragOver={(event) => {
event.preventDefault();
setIsDragOver(true);
}}
onDragLeave={() => setIsDragOver(false)}
onDrop={handleDropToRemoveFromFolder}
className={`pt-1 transition duration-300 ease-in-out mr-3 ${
isDragOver ? "bg-hover" : ""
} rounded-md`}
>
{(page == "chat" || page == "search") && (
<p className="my-2 text-xs text-sidebar-subtle flex font-bold">
{page == "chat" && "Chat "}
{page == "search" && "Search "}
History
</p>
)}
{isHistoryEmpty ? (
<p className="text-sm mt-2 w-[250px]">
Try sending a message! Your chat history will appear here.
</p>
) : (
Object.entries(groupedChatSessions).map(
([dateRange, chatSessions], ind) => {
if (chatSessions.length > 0) {
return (
<div key={dateRange}>
<div
className={`text-xs text-text-sidebar-subtle ${
ind != 0 && "mt-5"
} flex pb-0.5 mb-1.5 font-medium`}
>
{dateRange}
</div>
{chatSessions
.filter((chat) => chat.folder_id === null)
.map((chat) => {
const isSelected = currentChatId === chat.id;
return (
<div key={`${chat.id}-${chat.name}`}>
<ChatSessionDisplay
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}
closeSidebar={closeSidebar}
search={page == "search"}
chatSession={chat}
isSelected={isSelected}
skipGradient={isDragOver}
/>
</div>
);
})}
</div>
);
}
}
)
)}
<FolderList
newFolderId={newFolderId}
folders={folders}
currentChatId={currentChatId}
openedFolders={openedFolders}
showShareModal={showShareModal}
showDeleteModal={showDeleteModal}
/>
</div>
{showDeleteAllModal && NEXT_PUBLIC_DELETE_ALL_CHATS_ENABLED && (
<div className="absolute w-full border-t border-t-border bg-background-100 bottom-0 left-0 p-4">
<button
className="w-full py-2 px-4 text-text-600 hover:text-text-800 bg-background-125 border border-border-strong/50 shadow-sm rounded-md transition-colors duration-200 flex items-center justify-center text-sm"
onClick={showDeleteAllModal}
>
<FiTrash2 className="mr-2" size={14} />
Clear All History
</button>
</div>
)}
<div
onDragOver={(event) => {
event.preventDefault();
setIsDragOver(true);
}}
onDragLeave={() => setIsDragOver(false)}
onDrop={handleDropToRemoveFromFolder}
className={`pt-1 transition duration-300 ease-in-out mr-3 ${
isDragOver ? "bg-hover" : ""
} rounded-md`}
>
{(page == "chat" || page == "search") && (
<p className="my-2 text-xs text-sidebar-subtle flex font-bold">
{page == "chat" && "Chat "}
{page == "search" && "Search "}
History
</p>
)}
{isHistoryEmpty ? (
<p className="text-sm mt-2 w-[250px]">
{page === "search"
? "Try running a search! Your search history will appear here."
: "Try sending a message! Your chat history will appear here."}
</p>
) : (
Object.entries(groupedChatSessions).map(
([dateRange, chatSessions], ind) => {
if (chatSessions.length > 0) {
return (
<div key={dateRange}>
<div
className={`text-xs text-text-sidebar-subtle ${
ind != 0 && "mt-5"
} flex pb-0.5 mb-1.5 font-medium`}
>
{dateRange}
</div>
{chatSessions
.filter((chat) => chat.folder_id === null)
.map((chat) => {
const isSelected = currentChatId === chat.id;
return (
<div key={`${chat.id}-${chat.name}`}>
<ChatSessionDisplay
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}
closeSidebar={closeSidebar}
search={page == "search"}
chatSession={chat}
isSelected={isSelected}
skipGradient={isDragOver}
/>
</div>
);
})}
</div>
);
}
}
)
)}
</div>
</div>

View File

@@ -1,83 +1,53 @@
"use client";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/logo/Logo";
import { Logo } from "@/components/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import Link from "next/link";
import { useContext } from "react";
import { FiSidebar } from "react-icons/fi";
import { LogoType } from "@/components/logo/Logo";
import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
import { useRouter } from "next/navigation";
export function LogoComponent({
enterpriseSettings,
backgroundToggled,
show,
isAdmin,
}: {
enterpriseSettings: EnterpriseSettings | null;
backgroundToggled?: boolean;
show?: boolean;
isAdmin?: boolean;
}) {
const router = useRouter();
return (
<button
onClick={isAdmin ? () => router.push("/chat") : () => {}}
className={`max-w-[200px] ${
!show && "mobile:hidden"
} flex items-center gap-x-1`}
>
{enterpriseSettings && enterpriseSettings.application_name ? (
<>
<div className="flex-none my-auto">
<Logo height={24} width={24} />
</div>
<div className="w-full">
<HeaderTitle backgroundToggled={backgroundToggled}>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-left text-subtle">Powered by Onyx</p>
)}
</div>
</>
) : (
<LogoType />
)}
</button>
);
}
export default function FixedLogo({
// Whether the sidebar is toggled or not
backgroundToggled,
}: {
backgroundToggled?: boolean;
}) {
const combinedSettings = useContext(SettingsContext);
const settings = combinedSettings?.settings;
const enterpriseSettings = combinedSettings?.enterpriseSettings;
return (
<>
<Link
href="/chat"
className="fixed cursor-pointer flex z-40 left-4 top-3 h-8"
className="fixed cursor-pointer flex z-40 left-4 top-2 h-8"
>
<LogoComponent
enterpriseSettings={enterpriseSettings!}
backgroundToggled={backgroundToggled}
/>
<div className="max-w-[200px] mobile:hidden flex items-center gap-x-1 my-auto">
<div className="flex-none my-auto">
<Logo height={24} width={24} />
</div>
<div className="w-full">
{enterpriseSettings && enterpriseSettings.application_name ? (
<div>
<HeaderTitle backgroundToggled={backgroundToggled}>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Onyx</p>
)}
</div>
) : (
<HeaderTitle backgroundToggled={backgroundToggled}>
Onyx
</HeaderTitle>
)}
</div>
</div>
</Link>
<div className="mobile:hidden fixed left-4 bottom-4">
<FiSidebar
className={`${
backgroundToggled
? "text-text-mobile-sidebar-toggled"
: "text-text-mobile-sidebar-untoggled"
}`}
/>
<FiSidebar className="text-text-mobile-sidebar" />
</div>
</>
);

View File

@@ -85,8 +85,8 @@ export default function BillingInformationPage() {
{popup}
<h2 className="text-2xl font-bold mb-6 text-gray-800 flex items-center">
{/* <CreditCard className="mr-4 text-gray-600" size={24} /> */}
Subscription Details
<CreditCard className="mr-4 text-gray-600" size={24} />
Billing Information
</h2>
<div className="space-y-4">

View File

@@ -14,6 +14,7 @@ import { buildClientUrl } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import { EnterpriseSettings, GatingType } from "./admin/settings/interfaces";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/Logo";
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
import { AppProvider } from "@/components/context/AppProvider";
import { PHProvider } from "./providers";
@@ -22,7 +23,6 @@ import CardSection from "@/components/admin/CardSection";
import { Suspense } from "react";
import PostHogPageView from "./PostHogPageView";
import Script from "next/script";
import { LogoType } from "@/components/logo/Logo";
const inter = Inter({
subsets: ["latin"],
@@ -115,7 +115,8 @@ export default async function RootLayout({
return getPageContent(
<div className="flex flex-col items-center justify-center min-h-screen">
<div className="mb-2 flex items-center max-w-[175px]">
<LogoType />
<HeaderTitle>Onyx</HeaderTitle>
<Logo height={40} width={40} />
</div>
<CardSection className="max-w-md">
@@ -123,8 +124,7 @@ export default async function RootLayout({
<p className="text-text-500">
Your Onyx instance was not configured properly and your settings
could not be loaded. This could be due to an admin configuration
issue, an incomplete setup, or backend services that may not be up
and running yet.
issue or an incomplete setup.
</p>
<p className="mt-4">
If you&apos;re an admin, please check{" "}
@@ -144,7 +144,7 @@ export default async function RootLayout({
community on{" "}
<a
className="text-link"
href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ"
href="https://onyx.app?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
target="_blank"
rel="noopener noreferrer"
>
@@ -160,7 +160,8 @@ export default async function RootLayout({
return getPageContent(
<div className="flex flex-col items-center justify-center min-h-screen">
<div className="mb-2 flex items-center max-w-[175px]">
<LogoType />
<HeaderTitle>Onyx</HeaderTitle>
<Logo height={40} width={40} />
</div>
<CardSection className="w-full max-w-md">
<h1 className="text-2xl font-bold mb-4 text-error">

View File

@@ -1,7 +1,7 @@
"use client";
import { useContext } from "react";
import { SettingsContext } from "../settings/SettingsProvider";
import { SettingsContext } from "./settings/SettingsProvider";
import Image from "next/image";
export function Logo({
@@ -45,10 +45,10 @@ export function Logo({
);
}
export function LogoType() {
export default function LogoType() {
return (
<Image
className="max-h-8 w-full mr-auto "
className="max-h-8 mr-auto "
src="/logotype.png"
alt="Logo"
width={2640}

View File

@@ -0,0 +1,38 @@
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { HeaderTitle } from "./header/HeaderTitle";
import LogoType, { Logo } from "./Logo";
import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
export default function LogoTypeContainer({
enterpriseSettings,
}: {
enterpriseSettings: EnterpriseSettings | null;
}) {
const onlyLogo =
!enterpriseSettings ||
!enterpriseSettings.use_custom_logo ||
!enterpriseSettings.application_name;
return (
<div className="flex justify-start items-start w-full gap-x-1 my-auto">
<div className="flex-none w-fit mr-auto my-auto">
{onlyLogo ? <LogoType /> : <Logo height={24} width={24} />}
</div>
{!onlyLogo && (
<div className="w-full">
{enterpriseSettings && enterpriseSettings.application_name ? (
<div>
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Onyx</p>
)}
</div>
) : (
<HeaderTitle>Onyx</HeaderTitle>
)}
</div>
)}
</div>
);
}

View File

@@ -27,9 +27,7 @@ export function MetadataBadge({
size: 12,
className: flexNone ? "flex-none" : "mr-0.5 my-auto",
})}
<p className="max-w-[6rem] text-ellipsis overflow-hidden truncate whitespace-nowrap">
{value}lllaasfasdf
</p>
<div className="my-auto flex">{value}</div>
</div>
);
}

Some files were not shown because too many files have changed in this diff Show More