mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
1 Commits
cloud_debu
...
uf_theming
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
680a093646 |
@@ -1,19 +1,19 @@
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
rev: 24.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
|
||||
- repo: https://github.com/asottile/reorder_python_imports
|
||||
rev: v3.9.0
|
||||
hooks:
|
||||
- id: reorder-python-imports
|
||||
args: ['--py311-plus', '--application-directories=backend/']
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
- id: reorder-python-imports
|
||||
args: ["--py311-plus", "--application-directories=backend/"]
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
|
||||
# These settings will remove unused imports with side effects
|
||||
# Note: The repo currently does not and should not have imports with side effects
|
||||
@@ -21,7 +21,13 @@ repos:
|
||||
rev: v2.2.0
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
|
||||
args:
|
||||
[
|
||||
"--remove-all-unused-imports",
|
||||
"--remove-unused-variables",
|
||||
"--in-place",
|
||||
"--recursive",
|
||||
]
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
@@ -31,10 +37,10 @@ repos:
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
additional_dependencies:
|
||||
- prettier
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
additional_dependencies:
|
||||
- prettier
|
||||
|
||||
# We would like to have a mypy pre-commit hook, but due to the fact that
|
||||
# pre-commit runs in it's own isolated environment, we would need to install
|
||||
|
||||
@@ -1,48 +1,38 @@
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.base import Connection
|
||||
import os
|
||||
import ssl
|
||||
from typing import Literal
|
||||
import asyncio
|
||||
import logging
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Add your model's MetaData object here for 'autogenerate' support
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
if USE_IAM_AUTH:
|
||||
if not os.path.exists(SSL_CERT_FILE):
|
||||
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
@@ -59,12 +49,20 @@ def include_object(
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
Excludes specified tables from migrations.
|
||||
"""
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
"""
|
||||
Parses command-line options passed via '-x' in Alembic commands.
|
||||
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
|
||||
"""
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
@@ -92,12 +90,16 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
"""
|
||||
Executes migrations in the specified schema.
|
||||
"""
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
# Set search_path to the target schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
context.configure(
|
||||
@@ -115,25 +117,11 @@ def do_run_migrations(
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
region = AWS_REGION
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
|
||||
# Get IAM authentication token
|
||||
token = get_iam_auth_token(host, port, user, region)
|
||||
|
||||
# For Alembic / SQLAlchemy in this context, set SSL and password
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""
|
||||
Determines whether to run migrations for a single schema or all schemas,
|
||||
and executes migrations accordingly.
|
||||
"""
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
@@ -141,16 +129,10 @@ async def run_async_migrations() -> None:
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run migrations for all tenant schemas sequentially
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
|
||||
for schema in tenant_schemas:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
@@ -180,20 +162,15 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""
|
||||
Run migrations in 'offline' mode.
|
||||
"""
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run offline migrations for all tenant schemas
|
||||
engine = create_async_engine(url)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic_offline(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
@@ -230,6 +207,9 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""
|
||||
Runs migrations in 'online' mode using an asynchronous engine.
|
||||
"""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
"""properly_cascade
|
||||
|
||||
Revision ID: 35e518e0ddf4
|
||||
Revises: 91a0a4d62b14
|
||||
Create Date: 2024-09-20 21:24:04.891018
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e518e0ddf4"
|
||||
down_revision = "91a0a4d62b14"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Update chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Update chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add CASCADE delete for tool_call foreign key
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert tool_call foreign key constraint
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -1,87 +0,0 @@
|
||||
"""delete workspace
|
||||
|
||||
Revision ID: c0aab6edb6dd
|
||||
Revises: 35e518e0ddf4
|
||||
Create Date: 2024-12-17 14:37:07.660631
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0aab6edb6dd"
|
||||
down_revision = "35e518e0ddf4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = connector_specific_config - 'workspace'
|
||||
WHERE source = 'SLACK'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
import json
|
||||
from sqlalchemy import text
|
||||
from slack_sdk import WebClient
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Fetch all Slack credentials
|
||||
creds_result = conn.execute(
|
||||
text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'")
|
||||
)
|
||||
all_slack_creds = creds_result.fetchall()
|
||||
if not all_slack_creds:
|
||||
return
|
||||
|
||||
for cred_row in all_slack_creds:
|
||||
credential_id, credential_json = cred_row
|
||||
|
||||
credential_json = (
|
||||
credential_json.tobytes().decode("utf-8")
|
||||
if isinstance(credential_json, memoryview)
|
||||
else credential_json.decode("utf-8")
|
||||
)
|
||||
credential_data = json.loads(credential_json)
|
||||
slack_bot_token = credential_data.get("slack_bot_token")
|
||||
if not slack_bot_token:
|
||||
print(
|
||||
f"No slack_bot_token found for credential {credential_id}. "
|
||||
"Your Slack connector will not function until you upgrade and provide a valid token."
|
||||
)
|
||||
continue
|
||||
|
||||
client = WebClient(token=slack_bot_token)
|
||||
try:
|
||||
auth_response = client.auth_test()
|
||||
workspace = auth_response["url"].split("//")[1].split(".")[0]
|
||||
|
||||
# Update only the connectors linked to this credential
|
||||
# (and which are Slack connectors).
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE connector AS c
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{{workspace}}',
|
||||
to_jsonb('{workspace}'::text)
|
||||
)
|
||||
FROM connector_credential_pair AS ccp
|
||||
WHERE ccp.connector_id = c.id
|
||||
AND c.source = 'SLACK'
|
||||
AND ccp.credential_id = {credential_id}
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
f"We were unable to get the workspace url for your Slack Connector with id {credential_id}."
|
||||
)
|
||||
print("This connector will no longer work until you upgrade.")
|
||||
continue
|
||||
@@ -53,5 +53,3 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
@@ -3,15 +3,12 @@ import logging
|
||||
import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
@@ -50,16 +47,13 @@ from shared_configs.enums import EmbeddingProvider
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_provision_tenant(
|
||||
email: str, referral_source: str | None = None, request: Request | None = None
|
||||
async def get_or_create_tenant_id(
|
||||
email: str, referral_source: str | None = None
|
||||
) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
except exceptions.UserNotExists:
|
||||
@@ -287,36 +281,3 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
logger.info(
|
||||
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
|
||||
)
|
||||
|
||||
|
||||
async def submit_to_hubspot(
|
||||
email: str, referral_source: str | None, request: Request
|
||||
) -> None:
|
||||
if not HUBSPOT_TRACKING_URL:
|
||||
logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
|
||||
return
|
||||
|
||||
# HubSpot tracking cookie
|
||||
hubspot_cookie = request.cookies.get("hubspotutk")
|
||||
|
||||
# IP address
|
||||
ip_address = request.client.host if request.client else None
|
||||
|
||||
data = {
|
||||
"fields": [
|
||||
{"name": "email", "value": email},
|
||||
{"name": "referral_source", "value": referral_source or ""},
|
||||
],
|
||||
"context": {
|
||||
"hutk": hubspot_cookie,
|
||||
"ipAddress": ip_address,
|
||||
"pageUri": str(request.url),
|
||||
"pageName": "User Registration",
|
||||
},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(HUBSPOT_TRACKING_URL, json=data)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to submit to HubSpot: {response.text}")
|
||||
|
||||
@@ -1,38 +1,14 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
posthog = Posthog(project_api_key=POSTHOG_API_KEY, host=POSTHOG_HOST)
|
||||
|
||||
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
distinct_id: str,
|
||||
event: str,
|
||||
properties: dict | None = None,
|
||||
) -> None:
|
||||
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
|
||||
print("API KEY", POSTHOG_API_KEY)
|
||||
print("HOST", POSTHOG_HOST)
|
||||
try:
|
||||
print(type(distinct_id))
|
||||
print(type(event))
|
||||
print(type(properties))
|
||||
response = posthog.capture(distinct_id, event, properties)
|
||||
posthog.flush()
|
||||
print(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing Posthog event: {e}")
|
||||
posthog.capture(distinct_id, event, properties)
|
||||
|
||||
@@ -27,8 +27,8 @@ from shared_configs.configs import SENTRY_DSN
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
|
||||
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
|
||||
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
|
||||
HF_CACHE_PATH = Path("/root/.cache/huggingface/")
|
||||
TEMP_HF_CACHE_PATH = Path("/root/.cache/temp_huggingface/")
|
||||
|
||||
transformer_logging.set_verbosity_error()
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
@@ -229,26 +228,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
# We verify the password here to make sure it's valid before we proceed
|
||||
await self.validate_password(
|
||||
user_create.password, cast(schemas.UC, user_create)
|
||||
)
|
||||
|
||||
user_count: int | None = None
|
||||
referral_source = (
|
||||
request.cookies.get("referral_source", None)
|
||||
if request is not None
|
||||
else None
|
||||
)
|
||||
referral_source = None
|
||||
if request is not None:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -291,6 +282,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
# Blocking but this should be very quick
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if not user_count:
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=MilestoneRecordType.USER_SIGNED_UP,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=MilestoneRecordType.MULTIPLE_USERS,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
@@ -336,18 +346,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> User:
|
||||
referral_source = (
|
||||
getattr(request.state, "referral_source", None) if request else None
|
||||
)
|
||||
referral_source = None
|
||||
if request:
|
||||
referral_source = getattr(request.state, "referral_source", None)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not tenant_id:
|
||||
@@ -409,7 +418,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -463,39 +471,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
request=request,
|
||||
)
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
user_count = await get_user_count()
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if user_count == 1:
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=MilestoneRecordType.USER_SIGNED_UP,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=MilestoneRecordType.MULTIPLE_USERS,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
@@ -527,7 +502,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
@@ -588,7 +563,7 @@ class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
|
||||
@@ -3,6 +3,7 @@ import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
@@ -22,7 +23,6 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
@@ -262,8 +262,7 @@ def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
client = get_vespa_http_client()
|
||||
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
@@ -153,6 +154,10 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -61,14 +61,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -62,14 +62,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -60,15 +60,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -84,14 +84,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from redis import Redis
|
||||
@@ -25,25 +23,3 @@ def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
|
||||
|
||||
def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to find a task for a particular queue in redis.
|
||||
It is priority aware and knows how to look through the multiple redis lists
|
||||
used to implement task prioritization.
|
||||
This operation is not atomic.
|
||||
|
||||
This is a linear search O(n) ... so be careful using it when the task queues can be larger.
|
||||
|
||||
Returns true if the id is in the queue, False if not.
|
||||
"""
|
||||
for priority in range(len(OnyxCeleryPriority)):
|
||||
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
|
||||
|
||||
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
|
||||
for task in tasks:
|
||||
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
|
||||
if task_dict.get("headers", {}).get("id") == task_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -4,80 +4,55 @@ from typing import Any
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": 60,
|
||||
},
|
||||
"options": {"priority": OnyxCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
@@ -17,7 +15,6 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
@@ -165,19 +162,11 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
time_start = time.monotonic()
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
# redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -282,7 +271,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
search_settings_instance,
|
||||
reindex,
|
||||
db_session,
|
||||
redis_client,
|
||||
r,
|
||||
tenant_id,
|
||||
)
|
||||
if attempt_id:
|
||||
@@ -297,9 +286,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
|
||||
db_session, redis_client
|
||||
)
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -317,25 +304,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
# rkuo: The following code logically appears to work, but the celery inspect code may be unstable
|
||||
# turning off for the moment to see if it helps cloud stability
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
# if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
# # clear any indexing fences that don't have associated celery tasks in progress
|
||||
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# # or be currently executing
|
||||
# try:
|
||||
# task_logger.info("Validating indexing fences...")
|
||||
# validate_indexing_fences(
|
||||
# tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
# )
|
||||
# except Exception:
|
||||
# task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
# redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -352,190 +320,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
return tasks_created
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
reserved_indexing_tasks: set[str] = set()
|
||||
active_indexing_tasks: set[str] = set()
|
||||
indexing_worker_names: list[str] = []
|
||||
|
||||
# filter for and create an indexing specific inspect object
|
||||
inspect = celery_app.control.inspect()
|
||||
workers: dict[str, Any] = inspect.ping() # type: ignore
|
||||
if not workers:
|
||||
raise ValueError("No workers found!")
|
||||
|
||||
for worker_name in list(workers.keys()):
|
||||
if "indexing" in worker_name:
|
||||
indexing_worker_names.append(worker_name)
|
||||
|
||||
if len(indexing_worker_names) == 0:
|
||||
raise ValueError("No indexing workers found!")
|
||||
|
||||
inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
|
||||
|
||||
# NOTE: each dict entry is a map of worker name to a list of tasks
|
||||
# we want sets for reserved task and active task id's to optimize
|
||||
# subsequent validation lookups
|
||||
|
||||
# get the list of reserved tasks
|
||||
reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore
|
||||
if reserved_tasks is None:
|
||||
raise ValueError("inspect_indexing.reserved() returned None!")
|
||||
|
||||
for _, task_list in reserved_tasks.items():
|
||||
for task in task_list:
|
||||
reserved_indexing_tasks.add(task["id"])
|
||||
|
||||
# get the list of active tasks
|
||||
active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore
|
||||
if active_tasks is None:
|
||||
raise ValueError("inspect_indexing.active() returned None!")
|
||||
|
||||
for _, task_list in active_tasks.items():
|
||||
for task in task_list:
|
||||
active_indexing_tasks.add(task["id"])
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_indexing_tasks,
|
||||
active_indexing_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
active_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. Active signal is renewed with a 5 minute TTL
|
||||
1.1 When the fence is created
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved or active list for a worker
|
||||
2. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
# the fence is just barely set up.
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within the indexing worker
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in active_tasks:
|
||||
# the celery task is active (aka currently executing)
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# we didn't find any direct indication that associated celery tasks exist, but they still might be there
|
||||
# due to gaps in our ability to check states during transitions
|
||||
# Rely on the active signal (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
|
||||
)
|
||||
if payload.index_attempt_id:
|
||||
try:
|
||||
mark_attempt_failed(
|
||||
payload.index_attempt_id,
|
||||
db_session,
|
||||
"validate_indexing_fence - Canceling index attempt due to missing celery tasks",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"validate_indexing_fence - Exception while marking index attempt as failed."
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
|
||||
def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
@@ -682,7 +469,6 @@ def try_creating_indexing_task(
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_active()
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
@@ -716,8 +502,6 @@ def try_creating_indexing_task(
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
redis_connector_index.set_active()
|
||||
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
@@ -858,7 +642,7 @@ def connector_indexing_proxy_task(
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
|
||||
# even though logging clearly indicates that they completed successfully
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
@@ -1088,7 +872,6 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@@ -90,11 +89,10 @@ logger = setup_logger()
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock_beat = r.lock(
|
||||
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -163,10 +161,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
|
||||
return
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
@@ -736,7 +730,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
time_start = time.monotonic()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -831,8 +824,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
@@ -145,7 +144,6 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
@@ -176,9 +174,6 @@ try:
|
||||
except ValueError:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
|
||||
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
|
||||
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
@@ -488,21 +483,6 @@ SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
|
||||
|
||||
# allow for custom error messages for different errors returned by litellm
|
||||
# for example, can specify: {"Violated content safety policy": "EVIL REQUEST!!!"}
|
||||
# to make it so that if an LLM call returns an error containing "Violated content safety policy"
|
||||
# the end user will see "EVIL REQUEST!!!" instead of the default error message.
|
||||
_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = os.environ.get(
|
||||
"LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS", ""
|
||||
)
|
||||
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS: dict[str, str] | None = None
|
||||
try:
|
||||
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = cast(
|
||||
dict[str, str], json.loads(_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
|
||||
@@ -63,10 +63,6 @@ LANGUAGE_CHAT_NAMING_HINT = (
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
# Number of prompts each persona should have
|
||||
NUM_PERSONA_PROMPTS = 4
|
||||
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
|
||||
@@ -49,7 +49,6 @@ POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
|
||||
SSL_CERT_FILE = "bundle.pem"
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
|
||||
@@ -275,10 +274,6 @@ class OnyxRedisLocks:
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
|
||||
|
||||
|
||||
class OnyxCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
HIGH = auto()
|
||||
|
||||
@@ -316,23 +316,6 @@ def update_chat_session(
|
||||
return chat_session
|
||||
|
||||
|
||||
def delete_all_chat_sessions_for_user(
|
||||
user: User | None, db_session: Session, hard_delete: bool = HARD_DELETE_CHATS
|
||||
) -> None:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
query = db_session.query(ChatSession).filter(
|
||||
ChatSession.user_id == user_id, ChatSession.onyxbot_flow.is_(False)
|
||||
)
|
||||
|
||||
if hard_delete:
|
||||
query.delete(synchronize_session=False)
|
||||
else:
|
||||
query.update({ChatSession.deleted: True}, synchronize_session=False)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_chat_session(
|
||||
user_id: UUID | None,
|
||||
chat_session_id: UUID,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -12,8 +10,6 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import ContextManager
|
||||
|
||||
import asyncpg # type: ignore
|
||||
import boto3
|
||||
import jwt
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
@@ -27,7 +23,6 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
@@ -42,7 +37,6 @@ from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -55,87 +49,28 @@ logger = setup_logger()
|
||||
SYNC_DB_API = "psycopg2"
|
||||
ASYNC_DB_API = "asyncpg"
|
||||
|
||||
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
|
||||
# global so we don't create more than one engine per process
|
||||
# outside of being best practice, this is needed so we can properly pool
|
||||
# connections and not create a new pool on every request
|
||||
|
||||
# Global so we don't create more than one engine per process
|
||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||
SessionFactory: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
|
||||
"""Create an SSL context if IAM authentication is enabled, else return None."""
|
||||
if USE_IAM_AUTH:
|
||||
return ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
return None
|
||||
|
||||
|
||||
ssl_context = create_ssl_context_if_iam()
|
||||
|
||||
|
||||
def get_iam_auth_token(
|
||||
host: str, port: str, user: str, region: str = "us-east-2"
|
||||
) -> str:
|
||||
"""
|
||||
Generate an IAM authentication token using boto3.
|
||||
"""
|
||||
client = boto3.client("rds", region_name=region)
|
||||
token = client.generate_db_auth_token(
|
||||
DBHostname=host, Port=int(port), DBUsername=user
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
def configure_psycopg2_iam_auth(
|
||||
cparams: dict[str, Any], host: str, port: str, user: str, region: str
|
||||
) -> None:
|
||||
"""
|
||||
Configure cparams for psycopg2 with IAM token and SSL.
|
||||
"""
|
||||
token = get_iam_auth_token(host, port, user, region)
|
||||
cparams["password"] = token
|
||||
cparams["sslmode"] = "require"
|
||||
cparams["sslrootcert"] = SSL_CERT_FILE
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
user: str = POSTGRES_USER,
|
||||
password: str = POSTGRES_PASSWORD,
|
||||
host: str = POSTGRES_HOST,
|
||||
port: str = POSTGRES_PORT,
|
||||
db: str = POSTGRES_DB,
|
||||
app_name: str | None = None,
|
||||
use_iam: bool = USE_IAM_AUTH,
|
||||
region: str = "us-west-2",
|
||||
) -> str:
|
||||
if use_iam:
|
||||
base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
|
||||
else:
|
||||
base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
# For asyncpg, do not include application_name in the connection string
|
||||
if app_name and db_api != "asyncpg":
|
||||
if "?" in base_conn_str:
|
||||
return f"{base_conn_str}&application_name={app_name}"
|
||||
else:
|
||||
return f"{base_conn_str}?application_name={app_name}"
|
||||
return base_conn_str
|
||||
|
||||
|
||||
if LOG_POSTGRES_LATENCY:
|
||||
|
||||
# Function to log before query execution
|
||||
@event.listens_for(Engine, "before_cursor_execute")
|
||||
def before_cursor_execute( # type: ignore
|
||||
conn, cursor, statement, parameters, context, executemany
|
||||
):
|
||||
conn.info["query_start_time"] = time.time()
|
||||
|
||||
# Function to log after query execution
|
||||
@event.listens_for(Engine, "after_cursor_execute")
|
||||
def after_cursor_execute( # type: ignore
|
||||
conn, cursor, statement, parameters, context, executemany
|
||||
):
|
||||
total_time = time.time() - conn.info["query_start_time"]
|
||||
# don't spam TOO hard
|
||||
if total_time > 0.1:
|
||||
logger.debug(
|
||||
f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
|
||||
@@ -143,6 +78,7 @@ if LOG_POSTGRES_LATENCY:
|
||||
|
||||
|
||||
if LOG_POSTGRES_CONN_COUNTS:
|
||||
# Global counter for connection checkouts and checkins
|
||||
checkout_count = 0
|
||||
checkin_count = 0
|
||||
|
||||
@@ -169,13 +105,21 @@ if LOG_POSTGRES_CONN_COUNTS:
|
||||
logger.debug(f"Total connection checkins: {checkin_count}")
|
||||
|
||||
|
||||
"""END DEBUGGING LOGGING"""
|
||||
|
||||
|
||||
def get_db_current_time(db_session: Session) -> datetime:
|
||||
"""Get the current time from Postgres representing the start of the transaction
|
||||
Within the same transaction this value will not update
|
||||
This datetime object returned should be timezone aware, default Postgres timezone is UTC
|
||||
"""
|
||||
result = db_session.execute(text("SELECT NOW()")).scalar()
|
||||
if result is None:
|
||||
raise ValueError("Database did not return a time")
|
||||
return result
|
||||
|
||||
|
||||
# Regular expression to validate schema names to prevent SQL injection
|
||||
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
|
||||
|
||||
@@ -184,9 +128,16 @@ def is_valid_schema_name(name: str) -> bool:
|
||||
|
||||
|
||||
class SqlEngine:
|
||||
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
|
||||
Will eventually subsume most of the standalone functions in this file.
|
||||
Sync only for now.
|
||||
"""
|
||||
|
||||
_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
|
||||
# Default parameters for engine creation
|
||||
DEFAULT_ENGINE_KWARGS = {
|
||||
"pool_size": 20,
|
||||
"max_overflow": 5,
|
||||
@@ -194,27 +145,33 @@ class SqlEngine:
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
|
||||
"""Private helper method to create and return an Engine."""
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
|
||||
)
|
||||
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
|
||||
engine = create_engine(connection_string, **merged_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
event.listen(engine, "do_connect", provide_iam_token)
|
||||
|
||||
return engine
|
||||
return create_engine(connection_string, **merged_kwargs)
|
||||
|
||||
@classmethod
|
||||
def init_engine(cls, **engine_kwargs: Any) -> None:
|
||||
"""Allow the caller to init the engine with extra params. Different clients
|
||||
such as the API server and different Celery workers and tasks
|
||||
need different settings.
|
||||
"""
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine(**engine_kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
|
||||
already been called. You probably want to init first!
|
||||
"""
|
||||
if not cls._engine:
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
@@ -223,10 +180,12 @@ class SqlEngine:
|
||||
|
||||
@classmethod
|
||||
def set_app_name(cls, app_name: str) -> None:
|
||||
"""Class method to set the app name."""
|
||||
cls._app_name = app_name
|
||||
|
||||
@classmethod
|
||||
def get_app_name(cls) -> str:
|
||||
"""Class method to get current app name."""
|
||||
if not cls._app_name:
|
||||
return ""
|
||||
return cls._app_name
|
||||
@@ -258,71 +217,56 @@ def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
for tenant in tenant_ids
|
||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
return valid_tenants
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
user: str = POSTGRES_USER,
|
||||
password: str = POSTGRES_PASSWORD,
|
||||
host: str = POSTGRES_HOST,
|
||||
port: str = POSTGRES_PORT,
|
||||
db: str = POSTGRES_DB,
|
||||
app_name: str | None = None,
|
||||
) -> str:
|
||||
if app_name:
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
def get_sqlalchemy_engine() -> Engine:
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
async def get_async_connection() -> Any:
|
||||
"""
|
||||
Custom connection function for async engine when using IAM auth.
|
||||
"""
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
db = POSTGRES_DB
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION)
|
||||
|
||||
# asyncpg requires 'ssl="require"' if SSL needed
|
||||
return await asyncpg.connect(
|
||||
user=user, password=token, host=host, port=int(port), database=db, ssl="require"
|
||||
)
|
||||
|
||||
|
||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
global _ASYNC_ENGINE
|
||||
if _ASYNC_ENGINE is None:
|
||||
app_name = SqlEngine.get_app_name() + "_async"
|
||||
connection_string = build_connection_string(
|
||||
db_api=ASYNC_DB_API,
|
||||
use_iam=USE_IAM_AUTH,
|
||||
)
|
||||
|
||||
connect_args: dict[str, Any] = {}
|
||||
if app_name:
|
||||
connect_args["server_settings"] = {"application_name": app_name}
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
|
||||
# Underlying asyncpg cannot accept application_name directly in the connection string
|
||||
# https://github.com/MagicStack/asyncpg/issues/798
|
||||
connection_string = build_connection_string()
|
||||
_ASYNC_ENGINE = create_async_engine(
|
||||
connection_string,
|
||||
connect_args=connect_args,
|
||||
connect_args={
|
||||
"server_settings": {
|
||||
"application_name": SqlEngine.get_app_name() + "_async"
|
||||
}
|
||||
},
|
||||
# async engine is only used by API server, so we can use those values
|
||||
# here as well
|
||||
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
|
||||
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
|
||||
def provide_iam_token_async(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
# For async engine using asyncpg, we still need to set the IAM token here.
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
# Dependency to get the current tenant ID
|
||||
# If no token is present, uses the default schema for this use case
|
||||
def get_current_tenant_id(request: Request) -> str:
|
||||
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -331,6 +275,7 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
# If no token is present, use the default schema or handle accordingly
|
||||
return current_value
|
||||
|
||||
try:
|
||||
@@ -344,6 +289,7 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
return tenant_id
|
||||
except jwt.InvalidTokenError:
|
||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
@@ -370,6 +316,7 @@ async def get_async_session_with_tenant(
|
||||
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
# Set the search_path to the tenant's schema
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
@@ -379,6 +326,8 @@ async def get_async_session_with_tenant(
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error setting search_path.")
|
||||
# You can choose to re-raise the exception or handle it
|
||||
# Here, we'll re-raise to prevent proceeding with an incorrect session
|
||||
raise
|
||||
else:
|
||||
yield session
|
||||
@@ -386,6 +335,9 @@ async def get_async_session_with_tenant(
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_default_tenant() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Get a database session using the current tenant ID from the context variable.
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
@@ -397,6 +349,7 @@ def get_session_with_tenant(
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
|
||||
This function:
|
||||
1. Sets the database schema to the specified tenant's schema.
|
||||
2. Preserves the tenant ID across the session.
|
||||
@@ -404,20 +357,27 @@ def get_session_with_tenant(
|
||||
4. Uses the default schema if no tenant ID is provided.
|
||||
"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Store the previous tenant ID
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
try:
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
@@ -430,17 +390,21 @@ def get_session_with_tenant(
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
finally:
|
||||
# Restore the previous tenant ID
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
|
||||
|
||||
|
||||
@@ -460,9 +424,12 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
raise BasicAuthenticationError(
|
||||
detail="User must authenticate",
|
||||
)
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
@@ -470,17 +437,20 @@ def get_session() -> Generator[Session, None, None]:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Generate an async database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield async_session
|
||||
|
||||
@@ -491,6 +461,7 @@ def get_session_context_manager() -> ContextManager[Session]:
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
|
||||
"""Get a session factory."""
|
||||
global SessionFactory
|
||||
if SessionFactory is None:
|
||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
||||
@@ -518,13 +489,3 @@ async def warm_up_connections(
|
||||
await async_conn.execute(text("SELECT 1"))
|
||||
for async_conn in async_connections:
|
||||
await async_conn.close()
|
||||
|
||||
|
||||
def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
region = os.getenv("AWS_REGION", "us-east-2")
|
||||
# Configure for psycopg2 with IAM token
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
@@ -5,8 +5,6 @@ from typing import Literal
|
||||
from typing import NotRequired
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
@@ -1010,7 +1008,7 @@ class ChatSession(Base):
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
)
|
||||
messages: Mapped[list["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session", cascade="all, delete-orphan"
|
||||
"ChatMessage", back_populates="chat_session"
|
||||
)
|
||||
persona: Mapped["Persona"] = relationship("Persona")
|
||||
|
||||
@@ -1078,8 +1076,6 @@ class ChatMessage(Base):
|
||||
"SearchDoc",
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="chat_messages",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
@@ -1348,11 +1344,6 @@ class StarterMessage(TypedDict):
|
||||
message: str
|
||||
|
||||
|
||||
class StarterMessageModel(BaseModel):
|
||||
name: str
|
||||
message: str
|
||||
|
||||
|
||||
class Persona(Base):
|
||||
__tablename__ = "persona"
|
||||
|
||||
|
||||
@@ -543,10 +543,6 @@ def upsert_persona(
|
||||
if tools is not None:
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
# We should only update display priority if it is not already set
|
||||
if existing_persona.display_priority is None:
|
||||
existing_persona.display_priority = display_priority
|
||||
|
||||
persona = existing_persona
|
||||
|
||||
else:
|
||||
|
||||
@@ -369,19 +369,6 @@ class AdminCapable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""Class must implement random document retrieval capability"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = 10,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""Retrieve random chunks matching the filters"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseIndex(
|
||||
Verifiable,
|
||||
Indexable,
|
||||
@@ -389,7 +376,6 @@ class BaseIndex(
|
||||
Deletable,
|
||||
AdminCapable,
|
||||
IdRetrievalCapable,
|
||||
RandomCapable,
|
||||
abc.ABC,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -218,10 +218,4 @@ schema DANSWER_CHUNK_NAME {
|
||||
expression: bm25(content) + (5 * bm25(title))
|
||||
}
|
||||
}
|
||||
|
||||
rank-profile random_ {
|
||||
first-phase {
|
||||
expression: random.match
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ import concurrent.futures
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import urllib
|
||||
@@ -535,7 +534,7 @@ class VespaIndex(DocumentIndex):
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with get_vespa_http_client(http2=False) as http_client:
|
||||
with get_vespa_http_client() as http_client:
|
||||
for index_name in index_names:
|
||||
params = httpx.QueryParams(
|
||||
{
|
||||
@@ -546,12 +545,8 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
while True:
|
||||
try:
|
||||
vespa_url = (
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}"
|
||||
)
|
||||
logger.debug(f'update_single PUT on URL "{vespa_url}"')
|
||||
resp = http_client.put(
|
||||
vespa_url,
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}",
|
||||
params=params,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update_dict,
|
||||
@@ -623,7 +618,7 @@ class VespaIndex(DocumentIndex):
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with get_vespa_http_client(http2=False) as http_client:
|
||||
with get_vespa_http_client() as http_client:
|
||||
for index_name in index_names:
|
||||
params = httpx.QueryParams(
|
||||
{
|
||||
@@ -634,12 +629,8 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
while True:
|
||||
try:
|
||||
vespa_url = (
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}"
|
||||
)
|
||||
logger.debug(f'delete_single DELETE on URL "{vespa_url}"')
|
||||
resp = http_client.delete(
|
||||
vespa_url,
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
|
||||
params=params,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
@@ -912,32 +903,6 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
logger.info("Batch deletion completed")
|
||||
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = 10,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""Retrieve random chunks matching the filters using Vespa's random ranking
|
||||
|
||||
This method is currently used for random chunk retrieval in the context of
|
||||
assistant starter message creation (passed as sample context for usage by the assistant).
|
||||
"""
|
||||
vespa_where_clauses = build_vespa_filters(filters, remove_trailing_and=True)
|
||||
|
||||
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
|
||||
|
||||
random_seed = random.randint(0, 1000000)
|
||||
|
||||
params: dict[str, str | int | float] = {
|
||||
"yql": yql,
|
||||
"hits": num_to_retrieve,
|
||||
"timeout": VESPA_TIMEOUT,
|
||||
"ranking.profile": "random_",
|
||||
"ranking.properties.random.seed": random_seed,
|
||||
}
|
||||
|
||||
return query_vespa(params)
|
||||
|
||||
|
||||
class _VespaDeleteRequest:
|
||||
def __init__(self, document_id: str, index_name: str) -> None:
|
||||
|
||||
@@ -55,9 +55,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
|
||||
return _illegal_xml_chars_RE.sub("", text)
|
||||
|
||||
|
||||
def get_vespa_http_client(
|
||||
no_timeout: bool = False, http2: bool = False
|
||||
) -> httpx.Client:
|
||||
def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
|
||||
"""
|
||||
Configure and return an HTTP client for communicating with Vespa,
|
||||
including authentication if needed.
|
||||
@@ -69,5 +67,5 @@ def get_vespa_http_client(
|
||||
else None,
|
||||
verify=False if not MANAGED_VESPA else True,
|
||||
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
|
||||
http2=http2,
|
||||
http2=True,
|
||||
)
|
||||
|
||||
@@ -19,12 +19,7 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
filters: IndexFilters,
|
||||
*,
|
||||
include_hidden: bool = False,
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
if vals is None:
|
||||
return ""
|
||||
@@ -83,9 +78,6 @@ def build_vespa_filters(
|
||||
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5] # We remove the trailing " and "
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
|
||||
@@ -453,9 +453,7 @@ class DefaultMultiLLM(LLM):
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
if (
|
||||
DISABLE_LITELLM_STREAMING or self.config.model_name == "o1-2024-12-17"
|
||||
): # TODO: remove once litellm supports streaming
|
||||
if DISABLE_LITELLM_STREAMING:
|
||||
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
|
||||
return
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ OPENAI_PROVIDER_NAME = "openai"
|
||||
OPEN_AI_MODEL_NAMES = [
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"o1-2024-12-17",
|
||||
"gpt-4",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
|
||||
@@ -28,7 +28,6 @@ from litellm.exceptions import RateLimitError # type: ignore
|
||||
from litellm.exceptions import Timeout # type: ignore
|
||||
from litellm.exceptions import UnprocessableEntityError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -46,19 +45,10 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def litellm_exception_to_error_msg(
|
||||
e: Exception,
|
||||
llm: LLM,
|
||||
fallback_to_error_msg: bool = False,
|
||||
custom_error_msg_mappings: dict[str, str]
|
||||
| None = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
|
||||
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
|
||||
) -> str:
|
||||
error_msg = str(e)
|
||||
|
||||
if custom_error_msg_mappings:
|
||||
for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
|
||||
if error_msg_pattern in error_msg:
|
||||
return custom_error_msg
|
||||
|
||||
if isinstance(e, BadRequestError):
|
||||
error_msg = "Bad request: The server couldn't process your request. Please check your input."
|
||||
elif isinstance(e, AuthenticationError):
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
PERSONA_CATEGORY_GENERATION_PROMPT = """
|
||||
Based on the assistant's name, description, and instructions, generate a list of {num_categories}
|
||||
**unique and diverse** categories that represent different types of starter messages a user
|
||||
might send to initiate a conversation with this chatbot assistant.
|
||||
|
||||
**Ensure that the categories are varied and cover a wide range of topics related to the assistant's capabilities.**
|
||||
|
||||
Provide the categories as a JSON array of strings **without any code fences or additional text**.
|
||||
|
||||
**Context about the assistant:**
|
||||
- **Name**: {name}
|
||||
- **Description**: {description}
|
||||
- **Instructions**: {instructions}
|
||||
""".strip()
|
||||
|
||||
PERSONA_STARTER_MESSAGE_CREATION_PROMPT = """
|
||||
Create a starter message that a **user** might send to initiate a conversation with a chatbot assistant.
|
||||
|
||||
**Category**: {category}
|
||||
|
||||
Your response should include two parts:
|
||||
|
||||
1. **Title**: A short, engaging title that reflects the user's intent
|
||||
(e.g., 'Need Travel Advice', 'Question About Coding', 'Looking for Book Recommendations').
|
||||
|
||||
2. **Message**: The actual message that the user would send to the assistant.
|
||||
This should be natural, engaging, and encourage a helpful response from the assistant.
|
||||
**Avoid overly specific details; keep the message general and broadly applicable.**
|
||||
|
||||
For example:
|
||||
- Instead of "I've just adopted a 6-month-old Labrador puppy who's pulling on the leash,"
|
||||
write "I'm having trouble training my new puppy to walk nicely on a leash."
|
||||
|
||||
Ensure each part is clearly labeled and separated as shown above.
|
||||
Do not provide any additional text or explanation and be extremely concise
|
||||
|
||||
**Context about the assistant:**
|
||||
- **Name**: {name}
|
||||
- **Description**: {description}
|
||||
- **Instructions**: {instructions}
|
||||
""".strip()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(PERSONA_CATEGORY_GENERATION_PROMPT)
|
||||
print(PERSONA_STARTER_MESSAGE_CREATION_PROMPT)
|
||||
@@ -31,10 +31,6 @@ class RedisConnectorIndex:
|
||||
|
||||
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's difficult to prevent
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
@@ -58,7 +54,6 @@ class RedisConnectorIndex:
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
|
||||
)
|
||||
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
|
||||
|
||||
@classmethod
|
||||
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
|
||||
@@ -112,26 +107,6 @@ class RedisConnectorIndex:
|
||||
# 10 minute TTL is good.
|
||||
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
|
||||
|
||||
def set_active(self) -> None:
|
||||
"""This sets a signal to keep the indexing flow from getting cleaned up within
|
||||
the expiration time.
|
||||
|
||||
The slack in timing is needed to avoid race conditions where simply checking
|
||||
the celery queue and task status could result in race conditions."""
|
||||
self.redis.set(self.active_key, 0, ex=300)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def generator_locked(self) -> bool:
|
||||
if self.redis.exists(self.generator_lock_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_generator_complete(self, payload: int | None) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
@@ -163,7 +138,6 @@ class RedisConnectorIndex:
|
||||
return status
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.delete(self.active_key)
|
||||
self.redis.delete(self.generator_lock_key)
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from litellm import get_supported_openai_params
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
|
||||
from onyx.configs.chat_configs import NUM_PERSONA_PROMPTS
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.db.document_set import get_document_sets_by_ids
|
||||
from onyx.db.models import StarterMessageModel as StarterMessage
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT
|
||||
from onyx.prompts.starter_messages import PERSONA_STARTER_MESSAGE_CREATION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_random_chunks_from_doc_sets(
|
||||
doc_sets: List[str], db_session: Session, user: User | None = None
|
||||
) -> List[InferenceChunk]:
|
||||
"""
|
||||
Retrieves random chunks from the specified document sets.
|
||||
"""
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(curr_ind_name, sec_ind_name)
|
||||
|
||||
acl_filters = build_access_filters_for_user(user, db_session)
|
||||
filters = IndexFilters(document_set=doc_sets, access_control_list=acl_filters)
|
||||
|
||||
chunks = document_index.random_retrieval(
|
||||
filters=filters, num_to_retrieve=NUM_PERSONA_PROMPT_GENERATION_CHUNKS
|
||||
)
|
||||
return cleanup_chunks(chunks)
|
||||
|
||||
|
||||
def parse_categories(content: str) -> List[str]:
|
||||
"""
|
||||
Parses the JSON array of categories from the LLM response.
|
||||
"""
|
||||
# Clean the response to remove code fences and extra whitespace
|
||||
content = content.strip().strip("```").strip()
|
||||
if content.startswith("json"):
|
||||
content = content[4:].strip()
|
||||
|
||||
try:
|
||||
categories = json.loads(content)
|
||||
if not isinstance(categories, list):
|
||||
logger.error("Categories are not a list.")
|
||||
return []
|
||||
return categories
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse categories: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def generate_start_message_prompts(
|
||||
name: str,
|
||||
description: str,
|
||||
instructions: str,
|
||||
categories: List[str],
|
||||
chunk_contents: str,
|
||||
supports_structured_output: bool,
|
||||
fast_llm: Any,
|
||||
) -> List[FunctionCall]:
|
||||
"""
|
||||
Generates the list of FunctionCall objects for starter message generation.
|
||||
"""
|
||||
functions = []
|
||||
for category in categories:
|
||||
# Create a prompt specific to the category
|
||||
start_message_generation_prompt = (
|
||||
PERSONA_STARTER_MESSAGE_CREATION_PROMPT.format(
|
||||
name=name,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
if chunk_contents:
|
||||
start_message_generation_prompt += (
|
||||
"\n\nExample content this assistant has access to:\n"
|
||||
"'''\n"
|
||||
f"{chunk_contents}"
|
||||
"\n'''"
|
||||
)
|
||||
|
||||
if supports_structured_output:
|
||||
functions.append(
|
||||
FunctionCall(
|
||||
fast_llm.invoke,
|
||||
(start_message_generation_prompt, None, None, StarterMessage),
|
||||
)
|
||||
)
|
||||
else:
|
||||
functions.append(
|
||||
FunctionCall(
|
||||
fast_llm.invoke,
|
||||
(start_message_generation_prompt,),
|
||||
)
|
||||
)
|
||||
return functions
|
||||
|
||||
|
||||
def parse_unstructured_output(output: str) -> Dict[str, str]:
|
||||
"""
|
||||
Parses the assistant's unstructured output into a dictionary with keys:
|
||||
- 'name' (Title)
|
||||
- 'message' (Message)
|
||||
"""
|
||||
|
||||
# Debug output
|
||||
logger.debug(f"LLM Output for starter message creation: {output}")
|
||||
|
||||
# Patterns to match
|
||||
title_pattern = r"(?i)^\**Title\**\s*:\s*(.+)"
|
||||
message_pattern = r"(?i)^\**Message\**\s*:\s*(.+)"
|
||||
|
||||
# Initialize the response dictionary
|
||||
response_dict = {}
|
||||
|
||||
# Split the output into lines
|
||||
lines = output.strip().split("\n")
|
||||
|
||||
# Variables to keep track of the current key being processed
|
||||
current_key = None
|
||||
current_value_lines = []
|
||||
|
||||
for line in lines:
|
||||
# Check for title
|
||||
title_match = re.match(title_pattern, line.strip())
|
||||
if title_match:
|
||||
# Save previous key-value pair if any
|
||||
if current_key and current_value_lines:
|
||||
response_dict[current_key] = " ".join(current_value_lines).strip()
|
||||
current_value_lines = []
|
||||
current_key = "name"
|
||||
current_value_lines.append(title_match.group(1).strip())
|
||||
continue
|
||||
|
||||
# Check for message
|
||||
message_match = re.match(message_pattern, line.strip())
|
||||
if message_match:
|
||||
if current_key and current_value_lines:
|
||||
response_dict[current_key] = " ".join(current_value_lines).strip()
|
||||
current_value_lines = []
|
||||
current_key = "message"
|
||||
current_value_lines.append(message_match.group(1).strip())
|
||||
continue
|
||||
|
||||
# If the line doesn't match a new key, append it to the current value
|
||||
if current_key:
|
||||
current_value_lines.append(line.strip())
|
||||
|
||||
# Add the last key-value pair
|
||||
if current_key and current_value_lines:
|
||||
response_dict[current_key] = " ".join(current_value_lines).strip()
|
||||
|
||||
# Validate that the necessary keys are present
|
||||
if not all(k in response_dict for k in ["name", "message"]):
|
||||
raise ValueError("Failed to parse the assistant's response.")
|
||||
|
||||
return response_dict
|
||||
|
||||
|
||||
def generate_starter_messages(
|
||||
name: str,
|
||||
description: str,
|
||||
instructions: str,
|
||||
document_set_ids: List[int],
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
) -> List[StarterMessage]:
|
||||
"""
|
||||
Generates starter messages by first obtaining categories and then generating messages for each category.
|
||||
On failure, returns an empty list (or list with processed starter messages if some messages are processed successfully).
|
||||
"""
|
||||
_, fast_llm = get_default_llms(temperature=0.5)
|
||||
|
||||
provider = fast_llm.config.model_provider
|
||||
model = fast_llm.config.model_name
|
||||
|
||||
params = get_supported_openai_params(model=model, custom_llm_provider=provider)
|
||||
supports_structured_output = (
|
||||
isinstance(params, list) and "response_format" in params
|
||||
)
|
||||
|
||||
# Generate categories
|
||||
category_generation_prompt = PERSONA_CATEGORY_GENERATION_PROMPT.format(
|
||||
name=name,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
num_categories=NUM_PERSONA_PROMPTS,
|
||||
)
|
||||
|
||||
category_response = fast_llm.invoke(category_generation_prompt)
|
||||
categories = parse_categories(cast(str, category_response.content))
|
||||
|
||||
if not categories:
|
||||
logger.error("No categories were generated.")
|
||||
return []
|
||||
|
||||
# Fetch example content if document sets are provided
|
||||
if document_set_ids:
|
||||
document_sets = get_document_sets_by_ids(
|
||||
document_set_ids=document_set_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
chunks = get_random_chunks_from_doc_sets(
|
||||
doc_sets=[doc_set.name for doc_set in document_sets],
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Add example content context
|
||||
chunk_contents = "\n".join(chunk.content.strip() for chunk in chunks)
|
||||
else:
|
||||
chunk_contents = ""
|
||||
|
||||
# Generate prompts for starter messages
|
||||
functions = generate_start_message_prompts(
|
||||
name,
|
||||
description,
|
||||
instructions,
|
||||
categories,
|
||||
chunk_contents,
|
||||
supports_structured_output,
|
||||
fast_llm,
|
||||
)
|
||||
|
||||
# Run LLM calls in parallel
|
||||
if not functions:
|
||||
logger.error("No functions to execute for starter message generation.")
|
||||
return []
|
||||
|
||||
results = run_functions_in_parallel(function_calls=functions)
|
||||
prompts = []
|
||||
|
||||
for response in results.values():
|
||||
try:
|
||||
if supports_structured_output:
|
||||
response_dict = json.loads(response.content)
|
||||
else:
|
||||
response_dict = parse_unstructured_output(response.content)
|
||||
starter_message = StarterMessage(
|
||||
name=response_dict["name"],
|
||||
message=response_dict["message"],
|
||||
)
|
||||
prompts.append(starter_message)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.error(f"Failed to parse starter message: {e}")
|
||||
continue
|
||||
|
||||
return prompts
|
||||
@@ -48,7 +48,6 @@ def load_personas_from_yaml(
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_personas = data.get("personas", [])
|
||||
|
||||
for persona in all_personas:
|
||||
doc_set_names = persona["document_sets"]
|
||||
doc_sets: list[DocumentSetDBModel] = [
|
||||
@@ -128,7 +127,6 @@ def load_personas_from_yaml(
|
||||
display_priority=(
|
||||
existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
and persona.get("display_priority") is None
|
||||
else persona.get("display_priority")
|
||||
),
|
||||
is_visible=(
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import StarterMessageModel as StarterMessage
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.persona import create_assistant_category
|
||||
@@ -37,11 +36,7 @@ from onyx.db.persona import update_persona_shared_users
|
||||
from onyx.db.persona import update_persona_visibility
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.secondary_llm_flows.starter_message_creation import (
|
||||
generate_starter_messages,
|
||||
)
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.features.persona.models import GenerateStarterMessageRequest
|
||||
from onyx.server.features.persona.models import ImageGenerationToolStatus
|
||||
from onyx.server.features.persona.models import PersonaCategoryCreate
|
||||
from onyx.server.features.persona.models import PersonaCategoryResponse
|
||||
@@ -382,26 +377,3 @@ def build_final_template_prompt(
|
||||
retrieval_disabled=retrieval_disabled,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@basic_router.post("/assistant-prompt-refresh")
|
||||
def build_assistant_prompts(
|
||||
generate_persona_prompt_request: GenerateStarterMessageRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> list[StarterMessage]:
|
||||
try:
|
||||
logger.info(
|
||||
"Generating starter messages for user: %s", user.id if user else "Anonymous"
|
||||
)
|
||||
return generate_starter_messages(
|
||||
name=generate_persona_prompt_request.name,
|
||||
description=generate_persona_prompt_request.description,
|
||||
instructions=generate_persona_prompt_request.instructions,
|
||||
document_set_ids=generate_persona_prompt_request.document_set_ids,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate starter messages")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -17,14 +17,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# More minimal request for generating a persona prompt
|
||||
class GenerateStarterMessageRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
instructions: str
|
||||
document_set_ids: list[int]
|
||||
|
||||
|
||||
class CreatePersonaRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
@@ -35,7 +35,6 @@ from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
from onyx.db.chat import add_chats_to_session_from_slack_thread
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import delete_all_chat_sessions_for_user
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import duplicate_chat_session_for_user_from_slack
|
||||
from onyx.db.chat import get_chat_message
|
||||
@@ -281,17 +280,6 @@ def patch_chat_session(
|
||||
return None
|
||||
|
||||
|
||||
@router.delete("/delete-all-chat-sessions")
|
||||
def delete_all_chat_sessions(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
delete_all_chat_sessions_for_user(user=user, db_session=db_session)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/delete-chat-session/{session_id}")
|
||||
def delete_chat_session_by_id(
|
||||
session_id: UUID,
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.chat.models import RetrievalDocs
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import ChunkContext
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
@@ -152,10 +151,6 @@ class ChatSessionUpdateRequest(BaseModel):
|
||||
sharing_status: ChatSessionSharedStatus
|
||||
|
||||
|
||||
class DeleteAllSessionsRequest(BaseModel):
|
||||
session_type: SessionType
|
||||
|
||||
|
||||
class RenameChatSessionResponse(BaseModel):
|
||||
new_name: str # This is only really useful if the name is generated
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.utils.variable_functionality import (
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
|
||||
_CACHED_UUID: str | None = None
|
||||
_CACHED_INSTANCE_DOMAIN: str | None = None
|
||||
@@ -118,12 +117,9 @@ def mt_cloud_telemetry(
|
||||
event: MilestoneRecordType,
|
||||
properties: dict | None = None,
|
||||
) -> None:
|
||||
print(f"mt_cloud_telemetry {distinct_id} {event} {properties}")
|
||||
if not MULTI_TENANT:
|
||||
print("mt_cloud_telemetry not MULTI_TENANT")
|
||||
return
|
||||
|
||||
print("mt_cloud_telemetry MULTI_TENANT")
|
||||
# MIT version should not need to include any Posthog code
|
||||
# This is only for Onyx MT Cloud, this code should also never be hit, no reason for any orgs to
|
||||
# be running the Multi Tenant version of Onyx.
|
||||
@@ -141,11 +137,8 @@ def create_milestone_and_report(
|
||||
properties: dict | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
print(f"create_milestone_and_report {user} {event_type} {db_session}")
|
||||
_, is_new = create_milestone_if_not_exists(user, event_type, db_session)
|
||||
print(f"create_milestone_and_report {is_new}")
|
||||
if is_new:
|
||||
print("create_milestone_and_report is_new")
|
||||
mt_cloud_telemetry(
|
||||
distinct_id=distinct_id,
|
||||
event=event_type,
|
||||
|
||||
@@ -29,7 +29,7 @@ trafilatura==1.12.2
|
||||
langchain==0.1.17
|
||||
langchain-core==0.1.50
|
||||
langchain-text-splitters==0.0.1
|
||||
litellm==1.55.4
|
||||
litellm==1.54.1
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
llama-index==0.9.45
|
||||
|
||||
@@ -12,5 +12,5 @@ torch==2.2.0
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
litellm==1.55.4
|
||||
litellm==1.54.1
|
||||
sentry-sdk[fastapi,celery,starlette]==2.14.0
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-beat
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-heavy
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-indexing
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-light
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-primary
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
|
||||
@@ -92,7 +92,6 @@ services:
|
||||
- LOG_POSTGRES_LATENCY=${LOG_POSTGRES_LATENCY:-}
|
||||
- LOG_POSTGRES_CONN_COUNTS=${LOG_POSTGRES_CONN_COUNTS:-}
|
||||
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
|
||||
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
|
||||
|
||||
# Analytics Configs
|
||||
- SENTRY_DSN=${SENTRY_DSN:-}
|
||||
@@ -104,13 +103,6 @@ services:
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
||||
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
|
||||
# Seeding configuration
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
|
||||
- AWS_REGION=${AWS_REGION-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
@@ -231,13 +223,6 @@ services:
|
||||
|
||||
# Enterprise Edition stuff
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
|
||||
- AWS_REGION=${AWS_REGION-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
|
||||
@@ -84,7 +84,6 @@ services:
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
|
||||
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
|
||||
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
|
||||
|
||||
# Chat Configs
|
||||
- HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-}
|
||||
@@ -92,13 +91,6 @@ services:
|
||||
# Enterprise Edition only
|
||||
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH}
|
||||
- AWS_REGION=${AWS_REGION-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
@@ -200,13 +192,6 @@ services:
|
||||
# Enterprise Edition only
|
||||
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH}
|
||||
- AWS_REGION=${AWS_REGION-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
|
||||
@@ -22,13 +22,6 @@ services:
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH}
|
||||
- AWS_REGION=${AWS_REGION-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
@@ -59,13 +52,6 @@ services:
|
||||
- REDIS_HOST=cache
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH}
|
||||
- AWS_REGION=${AWS_REGION-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
|
||||
@@ -23,13 +23,6 @@ services:
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH}
|
||||
- AWS_REGION=${AWS_REGION-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
@@ -64,13 +57,6 @@ services:
|
||||
- REDIS_HOST=cache
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH}
|
||||
- AWS_REGION=${AWS_REGION-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
@@ -237,7 +223,7 @@ services:
|
||||
volumes:
|
||||
- ../data/certbot/conf:/etc/letsencrypt
|
||||
- ../data/certbot/www:/var/www/certbot
|
||||
logging::wq
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
@@ -259,6 +245,3 @@ volumes:
|
||||
# Created by the container itself
|
||||
model_cache_huggingface:
|
||||
indexing_huggingface_model_cache:
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -60,12 +60,3 @@ spec:
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
# Uncomment if you are using IAM auth for Postgres
|
||||
# volumeMounts:
|
||||
# - name: bundle-pem
|
||||
# mountPath: "/app/certs"
|
||||
# readOnly: true
|
||||
# volumes:
|
||||
# - name: bundle-pem
|
||||
# secret:
|
||||
# secretName: bundle-pem-secret
|
||||
|
||||
@@ -43,7 +43,6 @@ spec:
|
||||
# - name: my-ca-cert-volume
|
||||
# mountPath: /etc/ssl/certs/custom-ca.crt
|
||||
# subPath: my-ca.crt
|
||||
|
||||
# Optional volume for CA certificate
|
||||
# volumes:
|
||||
# - name: my-cas-cert-volume
|
||||
@@ -52,13 +51,3 @@ spec:
|
||||
# items:
|
||||
# - key: my-ca.crt
|
||||
# path: my-ca.crt
|
||||
|
||||
# Uncomment if you are using IAM auth for Postgres
|
||||
# volumeMounts:
|
||||
# - name: bundle-pem
|
||||
# mountPath: "/app/certs"
|
||||
# readOnly: true
|
||||
# volumes:
|
||||
# - name: bundle-pem
|
||||
# secret:
|
||||
# secretName: bundle-pem-secret
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 7.0 KiB |
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 340 KiB |
@@ -1,6 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="1.33325" y="1.3335" width="6.33333" height="6.33333" fill="#F25022"/>
|
||||
<rect x="8.33325" y="1.3335" width="6.33333" height="6.33333" fill="#80BA01"/>
|
||||
<rect x="8.33325" y="8.3335" width="6.33333" height="6.33333" fill="#FFB902"/>
|
||||
<rect x="1.33325" y="8.3335" width="6.33333" height="6.33333" fill="#02A4EF"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 425 B |
@@ -1 +0,0 @@
|
||||
<svg viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg" fill-rule="evenodd" clip-rule="evenodd" stroke-linejoin="round" stroke-miterlimit="2"><path d="M189.08 303.228H94.587l.044-94.446h94.497l-.048 94.446z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M283.528 397.674h-94.493l.044-94.446h94.496l-.047 94.446z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M283.575 303.228H189.08l.046-94.446h94.496l-.047 94.446z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M378.07 303.228h-94.495l.044-94.446h94.498l-.047 94.446zM189.128 208.779H94.633l.044-94.448h94.498l-.047 94.448zM378.115 208.779h-94.494l.045-94.448h94.496l-.047 94.448zM94.587 303.227H.093l.044-96.017h94.496l-.046 96.017z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M94.633 208.779H.138l.046-94.448H94.68l-.047 94.448z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M94.68 115.902H.185L.23 19.885h94.498l-.047 96.017zM472.657 114.331h-94.495l.044-94.446h94.497l-.046 94.446zM94.54 399.244H.046l.044-97.588h94.497l-.047 97.588z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M94.495 492.123H0l.044-94.446H94.54l-.045 94.446zM472.563 303.228H378.07l.044-94.446h94.496l-.047 94.446zM472.61 208.779h-94.495l.044-94.448h94.498l-.047 94.448z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M472.517 397.674h-94.494l.044-94.446h94.497l-.047 94.446z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M472.47 492.121h-94.493l.044-96.017h94.496l-.047 96.017z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M228.375 303.22h-96.061l.046-94.446h96.067l-.052 94.446z" fill="#ff7000" fill-rule="nonzero"/><path d="M322.827 397.666h-94.495l.044-96.018h94.498l-.047 96.018z" fill="#ff4900" fill-rule="nonzero"/><path d="M324.444 303.22h-97.636l.046-94.446h97.638l-.048 94.446z" fill="#ff7000" fill-rule="nonzero"/><path d="M418.938 303.22h-96.064l.045-94.446h96.066l-.047 94.446z" fill="#ff7000" fill-rule="nonzero"/><path d="M228.423 208.77H132.36l.045-94.445h96.066l-.05 94.446zM418.985 208.77H322.92l.044-94.445h96.069l-.048 94.446z" fill="#ffa300" fill-rule="nonzero"/><path d="M133.883 304.79H39.392l.044-96.017h94.496l-.049 96.017z" fill="#ff7000" fill-rule="nonzero"/><path d="M133.929 208.77H39.437l.044-95.445h94.496l-.048 95.445z" fill="#ffa300" fill-rule="nonzero"/><path d="M133.976 114.325H39.484l.044-94.448h94.497l-.05 94.448zM511.954 115.325h-94.493l.044-95.448h94.497l-.048 95.448z" fill="#ffce00" fill-rule="nonzero"/><path d="M133.836 399.667H39.345l.044-96.447h94.496l-.049 96.447z" fill="#ff4900" fill-rule="nonzero"/><path d="M133.79 492.117H39.3l.044-94.448h94.496l-.049 94.448z" fill="#ff0107" fill-rule="nonzero"/><path d="M511.862 303.22h-94.495l.046-94.446h94.496l-.047 94.446z" fill="#ff7000" fill-rule="nonzero"/><path d="M511.907 208.77h-94.493l.044-94.445h94.496l-.047 94.446z" fill="#ffa300" fill-rule="nonzero"/><path d="M511.815 398.666h-94.493l.044-95.447h94.496l-.047 95.447z" fill="#ff4900" fill-rule="nonzero"/><path d="M511.77 492.117h-94.496l.046-94.448h94.496l-.047 94.448z" fill="#ff0107" fill-rule="nonzero"/></svg>
|
||||
|
Before Width: | Height: | Size: 2.9 KiB |
@@ -75,8 +75,7 @@ export default function Page() {
|
||||
},
|
||||
{} as Record<SourceCategory, SourceMetadata[]>
|
||||
);
|
||||
}, [sources, filterSources, searchTerm]);
|
||||
|
||||
}, [sources, searchTerm]);
|
||||
const handleKeyPress = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === "Enter") {
|
||||
const filteredCategories = Object.entries(categorizedSources).filter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { Option } from "@/components/Dropdown";
|
||||
import { generateRandomIconShape, createSVG } from "@/lib/assistantIconUtils";
|
||||
|
||||
import { CCPairBasicInfo, DocumentSet, User } from "@/lib/types";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Button } from "@/components/ui/button";
|
||||
@@ -9,11 +9,12 @@ import { Textarea } from "@/components/ui/textarea";
|
||||
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
|
||||
import {
|
||||
ArrayHelpers,
|
||||
ErrorMessage,
|
||||
Field,
|
||||
FieldArray,
|
||||
Form,
|
||||
Formik,
|
||||
FormikProps,
|
||||
useFormikContext,
|
||||
} from "formik";
|
||||
|
||||
import {
|
||||
@@ -26,6 +27,7 @@ import {
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
|
||||
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
|
||||
import { Option } from "@/components/Dropdown";
|
||||
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
|
||||
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
@@ -39,9 +41,10 @@ import {
|
||||
} from "@/components/ui/tooltip";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { FiInfo, FiRefreshCcw } from "react-icons/fi";
|
||||
import { useEffect, useState } from "react";
|
||||
import { FiInfo, FiX } from "react-icons/fi";
|
||||
import * as Yup from "yup";
|
||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||
import CollapsibleSection from "./CollapsibleSection";
|
||||
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
||||
import { Persona, PersonaCategory, StarterMessage } from "./interfaces";
|
||||
@@ -63,9 +66,6 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import { buildImgUrl } from "@/app/chat/files/images/utils";
|
||||
import { LlmList } from "@/components/llm/LLMList";
|
||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||
import { debounce } from "lodash";
|
||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||
import StarterMessagesList from "./StarterMessageList";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { CategoryCard } from "./CategoryCard";
|
||||
|
||||
@@ -129,14 +129,12 @@ export function AssistantEditor({
|
||||
];
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
const [hasEditedStarterMessage, setHasEditedStarterMessage] = useState(false);
|
||||
const [showPersonaCategory, setShowPersonaCategory] = useState(!admin);
|
||||
|
||||
// state to persist across formik reformatting
|
||||
const [defautIconColor, _setDeafultIconColor] = useState(
|
||||
colorOptions[Math.floor(Math.random() * colorOptions.length)]
|
||||
);
|
||||
const [isRefreshing, setIsRefreshing] = useState(false);
|
||||
|
||||
const [defaultIconShape, setDefaultIconShape] = useState<any>(null);
|
||||
|
||||
@@ -150,10 +148,6 @@ export function AssistantEditor({
|
||||
|
||||
const [removePersonaImage, setRemovePersonaImage] = useState(false);
|
||||
|
||||
const autoStarterMessageEnabled = useMemo(
|
||||
() => llmProviders.length > 0,
|
||||
[llmProviders.length]
|
||||
);
|
||||
const isUpdate = existingPersona !== undefined && existingPersona !== null;
|
||||
const existingPrompt = existingPersona?.prompts[0] ?? null;
|
||||
const defaultProvider = llmProviders.find(
|
||||
@@ -223,24 +217,7 @@ export function AssistantEditor({
|
||||
existingPersona?.llm_model_provider_override ?? null,
|
||||
llm_model_version_override:
|
||||
existingPersona?.llm_model_version_override ?? null,
|
||||
starter_messages: existingPersona?.starter_messages ?? [
|
||||
{
|
||||
name: "",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "",
|
||||
message: "",
|
||||
},
|
||||
],
|
||||
starter_messages: existingPersona?.starter_messages ?? [],
|
||||
enabled_tools_map: enabledToolsMap,
|
||||
icon_color: existingPersona?.icon_color ?? defautIconColor,
|
||||
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
|
||||
@@ -251,44 +228,6 @@ export function AssistantEditor({
|
||||
groups: existingPersona?.groups ?? [],
|
||||
};
|
||||
|
||||
interface AssistantPrompt {
|
||||
message: string;
|
||||
name: string;
|
||||
}
|
||||
|
||||
const debouncedRefreshPrompts = debounce(
|
||||
async (values: any, setFieldValue: any) => {
|
||||
if (!autoStarterMessageEnabled) {
|
||||
return;
|
||||
}
|
||||
setIsRefreshing(true);
|
||||
try {
|
||||
const response = await fetch("/api/persona/assistant-prompt-refresh", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: values.name,
|
||||
description: values.description,
|
||||
document_set_ids: values.document_set_ids,
|
||||
instructions: values.system_prompt || values.task_prompt,
|
||||
}),
|
||||
});
|
||||
|
||||
const data: AssistantPrompt = await response.json();
|
||||
if (response.ok) {
|
||||
setFieldValue("starter_messages", data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to refresh prompts:", error);
|
||||
} finally {
|
||||
setIsRefreshing(false);
|
||||
}
|
||||
},
|
||||
1000
|
||||
);
|
||||
|
||||
const [isRequestSuccessful, setIsRequestSuccessful] = useState(false);
|
||||
|
||||
return (
|
||||
@@ -482,8 +421,6 @@ export function AssistantEditor({
|
||||
isSubmitting,
|
||||
values,
|
||||
setFieldValue,
|
||||
errors,
|
||||
|
||||
...formikProps
|
||||
}: FormikProps<any>) => {
|
||||
function toggleToolInValues(toolId: number) {
|
||||
@@ -508,7 +445,6 @@ export function AssistantEditor({
|
||||
|
||||
return (
|
||||
<Form className="w-full text-text-950">
|
||||
{/* Refresh starter messages when name or description changes */}
|
||||
<div className="w-full flex gap-x-2 justify-center">
|
||||
<Popover
|
||||
open={isIconDropdownOpen}
|
||||
@@ -1048,91 +984,6 @@ export function AssistantEditor({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mb-6 w-full flex flex-col">
|
||||
<div className="flex gap-x-2 items-center">
|
||||
<div className="block font-medium text-base">
|
||||
Starter Messages
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<SubLabel>
|
||||
Pre-configured messages that help users understand what this
|
||||
assistant can do and how to interact with it effectively.
|
||||
</SubLabel>
|
||||
<div className="relative w-fit">
|
||||
<TooltipProvider delayDuration={50}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<Button
|
||||
type="button"
|
||||
size="sm"
|
||||
onClick={() =>
|
||||
debouncedRefreshPrompts(values, setFieldValue)
|
||||
}
|
||||
disabled={
|
||||
!autoStarterMessageEnabled ||
|
||||
isRefreshing ||
|
||||
(Object.keys(errors).length > 0 &&
|
||||
Object.keys(errors).some(
|
||||
(key) => !key.startsWith("starter_messages")
|
||||
))
|
||||
}
|
||||
className={`
|
||||
px-3 py-2
|
||||
mr-auto
|
||||
my-2
|
||||
flex gap-x-2
|
||||
text-sm font-medium
|
||||
rounded-lg shadow-sm
|
||||
items-center gap-2
|
||||
transition-colors duration-200
|
||||
${
|
||||
isRefreshing || !autoStarterMessageEnabled
|
||||
? "bg-gray-100 text-gray-400 cursor-not-allowed"
|
||||
: "bg-blue-50 text-blue-600 hover:bg-blue-100 active:bg-blue-200"
|
||||
}
|
||||
`}
|
||||
>
|
||||
<div className="flex items-center gap-x-2">
|
||||
{isRefreshing ? (
|
||||
<FiRefreshCcw className="w-4 h-4 animate-spin text-gray-400" />
|
||||
) : (
|
||||
<SwapIcon className="w-4 h-4 text-blue-600" />
|
||||
)}
|
||||
Generate
|
||||
</div>
|
||||
</Button>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{!autoStarterMessageEnabled && (
|
||||
<TooltipContent side="top" align="center">
|
||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||
No LLM providers configured. Generation is not
|
||||
available.
|
||||
</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<FieldArray
|
||||
name="starter_messages"
|
||||
render={(arrayHelpers: ArrayHelpers) => (
|
||||
<StarterMessagesList
|
||||
isRefreshing={isRefreshing}
|
||||
values={values.starter_messages}
|
||||
arrayHelpers={arrayHelpers}
|
||||
touchStarterMessages={() => {
|
||||
setHasEditedStarterMessage(true);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{admin && (
|
||||
<AdvancedOptionsToggle
|
||||
title="Categories"
|
||||
@@ -1339,12 +1190,136 @@ export function AssistantEditor({
|
||||
</>
|
||||
)}
|
||||
|
||||
<div className="mb-6 flex flex-col">
|
||||
<div className="flex gap-x-2 items-center">
|
||||
<div className="block font-medium text-base">
|
||||
Starter Messages (Optional){" "}
|
||||
</div>
|
||||
</div>
|
||||
<SubLabel>
|
||||
Add pre-defined messages to help users get started. Only
|
||||
the first 4 will be displayed.
|
||||
</SubLabel>
|
||||
<FieldArray
|
||||
name="starter_messages"
|
||||
render={(
|
||||
arrayHelpers: ArrayHelpers<StarterMessage[]>
|
||||
) => (
|
||||
<div>
|
||||
{values.starter_messages &&
|
||||
values.starter_messages.length > 0 &&
|
||||
values.starter_messages.map(
|
||||
(
|
||||
starterMessage: StarterMessage,
|
||||
index: number
|
||||
) => {
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className={index === 0 ? "mt-2" : "mt-6"}
|
||||
>
|
||||
<div className="flex">
|
||||
<div className="w-full mr-6 border border-border p-3 rounded">
|
||||
<div>
|
||||
<Label small>Name</Label>
|
||||
<SubLabel>
|
||||
Shows up as the "title"
|
||||
for this Starter Message. For
|
||||
example, "Write an email".
|
||||
</SubLabel>
|
||||
<Field
|
||||
name={`starter_messages[${index}].name`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mr-4
|
||||
`}
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`starter_messages[${index}].name`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mt-3">
|
||||
<Label small>Message</Label>
|
||||
<SubLabel>
|
||||
The actual message to be sent as the
|
||||
initial user message if a user
|
||||
selects this starter prompt. For
|
||||
example, "Write me an email to
|
||||
a client about a new billing feature
|
||||
we just released."
|
||||
</SubLabel>
|
||||
<Field
|
||||
name={`starter_messages[${index}].message`}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
bg-background
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
min-h-12
|
||||
mr-4
|
||||
line-clamp-
|
||||
`}
|
||||
as="textarea"
|
||||
autoComplete="off"
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`starter_messages[${index}].message`}
|
||||
component="div"
|
||||
className="text-error text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="my-auto">
|
||||
<FiX
|
||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
||||
onClick={() =>
|
||||
arrayHelpers.remove(index)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
)}
|
||||
|
||||
<Button
|
||||
onClick={() => {
|
||||
arrayHelpers.push({
|
||||
name: "",
|
||||
description: "",
|
||||
message: "",
|
||||
});
|
||||
}}
|
||||
className="mt-3"
|
||||
size="sm"
|
||||
variant="next"
|
||||
>
|
||||
Add New
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<IsPublicGroupSelector
|
||||
formikProps={{
|
||||
values,
|
||||
isSubmitting,
|
||||
setFieldValue,
|
||||
errors,
|
||||
...formikProps,
|
||||
}}
|
||||
objectName="assistant"
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { ArrayHelpers, ErrorMessage, Field, useFormikContext } from "formik";
|
||||
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@radix-ui/react-tooltip";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { FiInfo, FiTrash2, FiPlus } from "react-icons/fi";
|
||||
import { StarterMessage } from "./interfaces";
|
||||
import { Label } from "@/components/admin/connectors/Field";
|
||||
|
||||
export default function StarterMessagesList({
|
||||
values,
|
||||
arrayHelpers,
|
||||
isRefreshing,
|
||||
touchStarterMessages,
|
||||
}: {
|
||||
values: StarterMessage[];
|
||||
arrayHelpers: ArrayHelpers;
|
||||
isRefreshing: boolean;
|
||||
touchStarterMessages: () => void;
|
||||
}) {
|
||||
const { handleChange } = useFormikContext();
|
||||
|
||||
// Group starter messages into rows of 2 for display purposes
|
||||
const rows = values.reduce((acc: StarterMessage[][], curr, i) => {
|
||||
if (i % 2 === 0) acc.push([curr]);
|
||||
else acc[acc.length - 1].push(curr);
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
const canAddMore = values.length <= 6;
|
||||
|
||||
return (
|
||||
<div className="mt-4 flex flex-col gap-6">
|
||||
{rows.map((row, rowIndex) => (
|
||||
<div key={rowIndex} className="flex items-start gap-4">
|
||||
<div className="grid grid-cols-2 gap-6 w-full xl:w-fit">
|
||||
{row.map((starterMessage, colIndex) => (
|
||||
<div
|
||||
key={rowIndex * 2 + colIndex}
|
||||
className="bg-white max-w-full w-full xl:w-[500px] border border-border rounded-lg shadow-md transition-shadow duration-200 p-6"
|
||||
>
|
||||
<div className="space-y-5">
|
||||
{isRefreshing ? (
|
||||
<div className="w-full">
|
||||
<div className="w-full">
|
||||
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
|
||||
<div className="h-10 w-full bg-gray-200 rounded animate-pulse" />
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
|
||||
<div className="h-10 w-full bg-gray-200 rounded animate-pulse" />
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
|
||||
<div className="h-24 w-full bg-gray-200 rounded animate-pulse" />
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div>
|
||||
<div className="flex w-full items-center gap-x-1">
|
||||
<Label
|
||||
small
|
||||
className="text-sm font-medium text-gray-700"
|
||||
>
|
||||
Name
|
||||
</Label>
|
||||
<TooltipProvider delayDuration={50}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger>
|
||||
<FiInfo size={12} />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" align="center">
|
||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||
Shows up as the "title" for this
|
||||
Starter Message. For example, "Write an
|
||||
email."
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<Field
|
||||
name={`starter_messages.${
|
||||
rowIndex * 2 + colIndex
|
||||
}.name`}
|
||||
className="mt-1 w-full px-4 py-2.5 bg-background border border-border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition"
|
||||
autoComplete="off"
|
||||
placeholder="Enter a name..."
|
||||
onChange={(e: any) => {
|
||||
touchStarterMessages();
|
||||
handleChange(e);
|
||||
}}
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`starter_messages.${
|
||||
rowIndex * 2 + colIndex
|
||||
}.name`}
|
||||
component="div"
|
||||
className="text-red-500 text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<div className="flex w-full items-center gap-x-1">
|
||||
<Label
|
||||
small
|
||||
className="text-sm font-medium text-gray-700"
|
||||
>
|
||||
Message
|
||||
</Label>
|
||||
<TooltipProvider delayDuration={50}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger>
|
||||
<FiInfo size={12} />
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="top" align="center">
|
||||
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
|
||||
The actual message to be sent as the initial
|
||||
user message.
|
||||
</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
<Field
|
||||
name={`starter_messages.${
|
||||
rowIndex * 2 + colIndex
|
||||
}.message`}
|
||||
className="mt-1 text-sm w-full px-4 py-2.5 bg-background border border-border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition min-h-[100px] resize-y"
|
||||
as="textarea"
|
||||
autoComplete="off"
|
||||
placeholder="Enter the message..."
|
||||
onChange={(e: any) => {
|
||||
touchStarterMessages();
|
||||
handleChange(e);
|
||||
}}
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={`starter_messages.${
|
||||
rowIndex * 2 + colIndex
|
||||
}.message`}
|
||||
component="div"
|
||||
className="text-red-500 text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
arrayHelpers.remove(rowIndex * 2 + 1);
|
||||
arrayHelpers.remove(rowIndex * 2);
|
||||
}}
|
||||
className="p-1.5 bg-white border border-gray-200 rounded-full text-gray-400 hover:text-red-500 hover:border-red-200 transition-colors mt-2"
|
||||
aria-label="Delete row"
|
||||
>
|
||||
<FiTrash2 size={14} />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{canAddMore && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
arrayHelpers.push({
|
||||
name: "",
|
||||
message: "",
|
||||
});
|
||||
arrayHelpers.push({
|
||||
name: "",
|
||||
message: "",
|
||||
});
|
||||
}}
|
||||
className="self-start flex items-center gap-2 px-4 py-2 bg-white border border-gray-200 rounded-lg text-gray-600 hover:bg-gray-50 hover:border-gray-300 transition-colors"
|
||||
>
|
||||
<FiPlus size={16} />
|
||||
<span>Add Row</span>
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,12 +1,8 @@
|
||||
import {
|
||||
AnthropicIcon,
|
||||
AmazonIcon,
|
||||
AWSIcon,
|
||||
AzureIcon,
|
||||
CPUIcon,
|
||||
MicrosoftIconSVG,
|
||||
MistralIcon,
|
||||
MetaIcon,
|
||||
OpenAIIcon,
|
||||
GeminiIcon,
|
||||
OpenSourceIcon,
|
||||
@@ -76,25 +72,12 @@ export const getProviderIcon = (providerName: string, modelName?: string) => {
|
||||
switch (providerName) {
|
||||
case "openai":
|
||||
// Special cases for openai based on modelName
|
||||
if (modelName?.toLowerCase().includes("amazon")) {
|
||||
return AmazonIcon;
|
||||
}
|
||||
if (modelName?.toLowerCase().includes("phi")) {
|
||||
return MicrosoftIconSVG;
|
||||
}
|
||||
if (modelName?.toLowerCase().includes("mistral")) {
|
||||
return MistralIcon;
|
||||
}
|
||||
if (modelName?.toLowerCase().includes("llama")) {
|
||||
return MetaIcon;
|
||||
}
|
||||
if (modelName?.toLowerCase().includes("gemini")) {
|
||||
return GeminiIcon;
|
||||
}
|
||||
if (modelName?.toLowerCase().includes("claude")) {
|
||||
return AnthropicIcon;
|
||||
}
|
||||
|
||||
return OpenAIIcon; // Default for openai
|
||||
case "anthropic":
|
||||
return AnthropicIcon;
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useFormContext } from "@/components/context/FormContext";
|
||||
import { HeaderTitle } from "@/components/header/HeaderTitle";
|
||||
|
||||
import { SettingsIcon } from "@/components/icons/icons";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { credentialTemplates } from "@/lib/connectors/credentials";
|
||||
import Link from "next/link";
|
||||
|
||||
@@ -6,7 +6,7 @@ import { useCallback, useEffect, useState } from "react";
|
||||
import Text from "@/components/ui/text";
|
||||
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
|
||||
import { User } from "@/lib/types";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
import { Logo } from "@/components/Logo";
|
||||
|
||||
export function Verify({ user }: { user: User | null }) {
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
@@ -8,7 +8,7 @@ import { HealthCheckBanner } from "@/components/health/healthcheck";
|
||||
import { User } from "@/lib/types";
|
||||
import Text from "@/components/ui/text";
|
||||
import { RequestNewVerificationEmail } from "./RequestNewVerificationEmail";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
import { Logo } from "@/components/Logo";
|
||||
|
||||
export default async function Page() {
|
||||
// catch cases where the backend is completely unreachable here
|
||||
|
||||
@@ -27,7 +27,6 @@ import {
|
||||
buildLatestMessageChain,
|
||||
checkAnyAssistantHasSearch,
|
||||
createChatSession,
|
||||
deleteAllChatSessions,
|
||||
deleteChatSession,
|
||||
getCitedDocumentsFromMessage,
|
||||
getHumanAndAIMessageFromMessageNumber,
|
||||
@@ -273,7 +272,6 @@ export function ChatPage({
|
||||
};
|
||||
|
||||
const llmOverrideManager = useLlmOverride(
|
||||
llmProviders,
|
||||
modelVersionFromSearchParams || (user?.preferences.default_model ?? null),
|
||||
selectedChatSession,
|
||||
defaultTemperature
|
||||
@@ -320,9 +318,9 @@ export function ChatPage({
|
||||
);
|
||||
|
||||
if (personaDefault) {
|
||||
llmOverrideManager.updateLLMOverride(personaDefault);
|
||||
llmOverrideManager.setLlmOverride(personaDefault);
|
||||
} else if (user?.preferences.default_model) {
|
||||
llmOverrideManager.updateLLMOverride(
|
||||
llmOverrideManager.setLlmOverride(
|
||||
destructureValue(user?.preferences.default_model)
|
||||
);
|
||||
}
|
||||
@@ -1204,6 +1202,7 @@ export function ChatPage({
|
||||
assistant_message_id: number;
|
||||
frozenMessageMap: Map<number, Message>;
|
||||
} = null;
|
||||
|
||||
try {
|
||||
const mapKeys = Array.from(
|
||||
currentMessageMap(completeMessageDetail).keys()
|
||||
@@ -1838,7 +1837,6 @@ export function ChatPage({
|
||||
|
||||
const innerSidebarElementRef = useRef<HTMLDivElement>(null);
|
||||
const [settingsToggled, setSettingsToggled] = useState(false);
|
||||
const [showDeleteAllModal, setShowDeleteAllModal] = useState(false);
|
||||
|
||||
const currentPersona = alternativeAssistant || liveAssistant;
|
||||
useEffect(() => {
|
||||
@@ -1905,6 +1903,11 @@ export function ChatPage({
|
||||
const showShareModal = (chatSession: ChatSession) => {
|
||||
setSharedChatSession(chatSession);
|
||||
};
|
||||
const [documentSelection, setDocumentSelection] = useState(false);
|
||||
// const toggleDocumentSelectionAspects = () => {
|
||||
// setDocumentSelection((documentSelection) => !documentSelection);
|
||||
// setShowDocSidebar(false);
|
||||
// };
|
||||
|
||||
const toggleDocumentSidebar = () => {
|
||||
if (!documentSidebarToggled) {
|
||||
@@ -1969,32 +1972,6 @@ export function ChatPage({
|
||||
|
||||
<ChatPopup />
|
||||
|
||||
{showDeleteAllModal && (
|
||||
<DeleteEntityModal
|
||||
entityType="All Chats"
|
||||
entityName="all your chat sessions"
|
||||
onClose={() => setShowDeleteAllModal(false)}
|
||||
additionalDetails="This action cannot be undone. All your chat sessions will be deleted."
|
||||
onSubmit={async () => {
|
||||
const response = await deleteAllChatSessions("Chat");
|
||||
if (response.ok) {
|
||||
setShowDeleteAllModal(false);
|
||||
setPopup({
|
||||
message: "All your chat sessions have been deleted.",
|
||||
type: "success",
|
||||
});
|
||||
refreshChatSessions();
|
||||
router.push("/chat");
|
||||
} else {
|
||||
setPopup({
|
||||
message: "Failed to delete all chat sessions.",
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{currentFeedback && (
|
||||
<FeedbackModal
|
||||
feedbackType={currentFeedback[0]}
|
||||
@@ -2146,7 +2123,7 @@ export function ChatPage({
|
||||
page="chat"
|
||||
ref={innerSidebarElementRef}
|
||||
toggleSidebar={toggleSidebar}
|
||||
toggled={toggledSidebar}
|
||||
toggled={toggledSidebar && !settings?.isMobile}
|
||||
backgroundToggled={toggledSidebar || showHistorySidebar}
|
||||
existingChats={chatSessions}
|
||||
currentChatSession={selectedChatSession}
|
||||
@@ -2155,7 +2132,6 @@ export function ChatPage({
|
||||
removeToggle={removeToggle}
|
||||
showShareModal={showShareModal}
|
||||
showDeleteModal={showDeleteModal}
|
||||
showDeleteAllModal={() => setShowDeleteAllModal(true)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -2168,6 +2144,7 @@ export function ChatPage({
|
||||
fixed
|
||||
right-0
|
||||
z-[1000]
|
||||
|
||||
bg-background
|
||||
h-screen
|
||||
transition-all
|
||||
@@ -2217,6 +2194,8 @@ export function ChatPage({
|
||||
{liveAssistant && (
|
||||
<FunctionalHeader
|
||||
toggleUserSettings={() => setUserSettingsToggled(true)}
|
||||
liveAssistant={liveAssistant}
|
||||
onAssistantChange={onAssistantChange}
|
||||
sidebarToggled={toggledSidebar}
|
||||
reset={() => setMessage("")}
|
||||
page="chat"
|
||||
@@ -2228,6 +2207,7 @@ export function ChatPage({
|
||||
toggleSidebar={toggleSidebar}
|
||||
currentChatSession={selectedChatSession}
|
||||
documentSidebarToggled={documentSidebarToggled}
|
||||
llmOverrideManager={llmOverrideManager}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -2762,10 +2742,6 @@ export function ChatPage({
|
||||
removeDocs={() => {
|
||||
clearSelectedDocuments();
|
||||
}}
|
||||
showDocs={() => {
|
||||
setFiltersToggled(false);
|
||||
setDocumentSidebarToggled(true);
|
||||
}}
|
||||
removeFilters={() => {
|
||||
filterManager.setSelectedSources([]);
|
||||
filterManager.setSelectedTags([]);
|
||||
@@ -2778,6 +2754,7 @@ export function ChatPage({
|
||||
chatState={currentSessionChatState}
|
||||
stopGenerating={stopGenerating}
|
||||
openModelSettings={() => setSettingsToggled(true)}
|
||||
showDocs={() => setDocumentSelection(true)}
|
||||
selectedDocuments={selectedDocuments}
|
||||
// assistant stuff
|
||||
selectedAssistant={liveAssistant}
|
||||
|
||||
@@ -124,9 +124,9 @@ export default function RegenerateOption({
|
||||
onHoverChange: (isHovered: boolean) => void;
|
||||
onDropdownVisibleChange: (isVisible: boolean) => void;
|
||||
}) {
|
||||
const { llmProviders } = useChatContext();
|
||||
const llmOverrideManager = useLlmOverride(llmProviders);
|
||||
const llmOverrideManager = useLlmOverride();
|
||||
|
||||
const { llmProviders } = useChatContext();
|
||||
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
|
||||
|
||||
const llmOptionsByProvider: {
|
||||
|
||||
@@ -81,8 +81,6 @@ export function ChatDocumentDisplay({
|
||||
}
|
||||
};
|
||||
|
||||
const hasMetadata =
|
||||
document.updated_at || Object.keys(document.metadata).length > 0;
|
||||
return (
|
||||
<div className={`opacity-100 ${modal ? "w-[90vw]" : "w-full"}`}>
|
||||
<div
|
||||
@@ -109,14 +107,8 @@ export function ChatDocumentDisplay({
|
||||
: document.semantic_identifier || document.document_id}
|
||||
</div>
|
||||
</div>
|
||||
{hasMetadata && (
|
||||
<DocumentMetadataBlock modal={modal} document={document} />
|
||||
)}
|
||||
<div
|
||||
className={`line-clamp-3 text-sm font-normal leading-snug text-gray-600 ${
|
||||
hasMetadata ? "mt-2" : ""
|
||||
}`}
|
||||
>
|
||||
<DocumentMetadataBlock modal={modal} document={document} />
|
||||
<div className="line-clamp-3 pt-2 text-sm font-normal leading-snug text-gray-600">
|
||||
{buildDocumentSummaryDisplay(
|
||||
document.match_highlights,
|
||||
document.blurb
|
||||
|
||||
@@ -31,7 +31,14 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { ChatState } from "../types";
|
||||
import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText";
|
||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||
import AnimatedToggle from "@/components/search/SearchBar";
|
||||
import { Popup } from "@/components/admin/connectors/Popup";
|
||||
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
|
||||
import { IconType } from "react-icons";
|
||||
import { LlmTab } from "../modal/configuration/LlmTab";
|
||||
import { XIcon } from "lucide-react";
|
||||
import { FilterPills } from "./FilterPills";
|
||||
import { Tag } from "@/lib/types";
|
||||
import FiltersDisplay from "./FilterDisplay";
|
||||
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
@@ -40,6 +47,7 @@ interface ChatInputBarProps {
|
||||
removeFilters: () => void;
|
||||
removeDocs: () => void;
|
||||
openModelSettings: () => void;
|
||||
showDocs: () => void;
|
||||
showConfigureAPIKey: () => void;
|
||||
selectedDocuments: OnyxDocument[];
|
||||
message: string;
|
||||
@@ -49,7 +57,6 @@ interface ChatInputBarProps {
|
||||
filterManager: FilterManager;
|
||||
llmOverrideManager: LlmOverrideManager;
|
||||
chatState: ChatState;
|
||||
showDocs: () => void;
|
||||
alternativeAssistant: Persona | null;
|
||||
// assistants
|
||||
selectedAssistant: Persona;
|
||||
@@ -68,8 +75,8 @@ export function ChatInputBar({
|
||||
removeFilters,
|
||||
removeDocs,
|
||||
openModelSettings,
|
||||
showConfigureAPIKey,
|
||||
showDocs,
|
||||
showConfigureAPIKey,
|
||||
selectedDocuments,
|
||||
message,
|
||||
setMessage,
|
||||
@@ -277,6 +284,10 @@ export function ChatInputBar({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* <div>
|
||||
<SelectedFilterDisplay filterManager={filterManager} />
|
||||
</div> */}
|
||||
|
||||
<UnconfiguredProviderText showConfigureAPIKey={showConfigureAPIKey} />
|
||||
|
||||
<div
|
||||
@@ -417,7 +428,9 @@ export function ChatInputBar({
|
||||
style={{ scrollbarWidth: "thin" }}
|
||||
role="textarea"
|
||||
aria-multiline
|
||||
placeholder="Ask me anything.."
|
||||
placeholder={`Send a message ${
|
||||
!settings?.isMobile ? "or try using @ or /" : ""
|
||||
}`}
|
||||
value={message}
|
||||
onKeyDown={(event) => {
|
||||
if (
|
||||
|
||||
@@ -278,16 +278,6 @@ export async function deleteChatSession(chatSessionId: string) {
|
||||
return response;
|
||||
}
|
||||
|
||||
export async function deleteAllChatSessions(sessionType: "Chat" | "Search") {
|
||||
const response = await fetch(`/api/chat/delete-all-chat-sessions`, {
|
||||
method: "DELETE",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
return response;
|
||||
}
|
||||
|
||||
export async function* simulateLLMResponse(input: string, delay: number = 30) {
|
||||
// Split the input string into tokens. This is a simple example, and in real use case, tokenization can be more complex.
|
||||
// Iterate over tokens and yield them one by one
|
||||
|
||||
@@ -35,7 +35,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
||||
checkPersonaRequiresImageGeneration(currentAssistant);
|
||||
|
||||
const { llmProviders } = useChatContext();
|
||||
const { updateLLMOverride, temperature, updateTemperature } =
|
||||
const { setLlmOverride, temperature, updateTemperature } =
|
||||
llmOverrideManager;
|
||||
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
|
||||
|
||||
@@ -60,7 +60,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
|
||||
if (value == null) {
|
||||
return;
|
||||
}
|
||||
updateLLMOverride(destructureValue(value));
|
||||
setLlmOverride(destructureValue(value));
|
||||
if (chatSessionId) {
|
||||
updateModelOverrideForChatSession(chatSessionId, value as string);
|
||||
}
|
||||
|
||||
@@ -11,10 +11,13 @@ import { createFolder } from "../folders/FolderManagement";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
|
||||
import { AssistantsIconSkeleton } from "@/components/icons/icons";
|
||||
import {
|
||||
AssistantsIconSkeleton,
|
||||
ClosedBookIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { PagesTab } from "./PagesTab";
|
||||
import { pageType } from "./types";
|
||||
import LogoWithText from "@/components/header/LogoWithText";
|
||||
import LogoType from "@/components/header/LogoType";
|
||||
|
||||
interface HistorySidebarProps {
|
||||
page: pageType;
|
||||
@@ -30,7 +33,6 @@ interface HistorySidebarProps {
|
||||
showDeleteModal?: (chatSession: ChatSession) => void;
|
||||
stopGenerating?: () => void;
|
||||
explicitlyUntoggle: () => void;
|
||||
showDeleteAllModal?: () => void;
|
||||
backgroundToggled?: boolean;
|
||||
}
|
||||
|
||||
@@ -50,7 +52,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
stopGenerating = () => null,
|
||||
showShareModal,
|
||||
showDeleteModal,
|
||||
showDeleteAllModal,
|
||||
backgroundToggled,
|
||||
},
|
||||
ref: ForwardedRef<HTMLDivElement>
|
||||
@@ -99,19 +100,16 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
flex
|
||||
flex-col relative
|
||||
h-screen
|
||||
pt-2
|
||||
transition-transform
|
||||
`}
|
||||
>
|
||||
<div className="pl-2">
|
||||
<LogoWithText
|
||||
showArrow={true}
|
||||
toggled={toggled}
|
||||
page={page}
|
||||
toggleSidebar={toggleSidebar}
|
||||
explicitlyUntoggle={explicitlyUntoggle}
|
||||
/>
|
||||
</div>
|
||||
<LogoType
|
||||
showArrow={true}
|
||||
toggled={toggled}
|
||||
page={page}
|
||||
toggleSidebar={toggleSidebar}
|
||||
explicitlyUntoggle={explicitlyUntoggle}
|
||||
/>
|
||||
{page == "chat" && (
|
||||
<div className="mx-3 mt-4 gap-y-1 flex-col text-text-history-sidebar-button flex gap-x-1.5 items-center items-center">
|
||||
<Link
|
||||
@@ -178,7 +176,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
currentChatId={currentChatId}
|
||||
folders={folders}
|
||||
openedFolders={openedFolders}
|
||||
showDeleteAllModal={showDeleteAllModal}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
|
||||
@@ -9,8 +9,6 @@ import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useState } from "react";
|
||||
import { pageType } from "./types";
|
||||
import { FiTrash2 } from "react-icons/fi";
|
||||
import { NEXT_PUBLIC_DELETE_ALL_CHATS_ENABLED } from "@/lib/constants";
|
||||
|
||||
export function PagesTab({
|
||||
page,
|
||||
@@ -22,7 +20,6 @@ export function PagesTab({
|
||||
newFolderId,
|
||||
showShareModal,
|
||||
showDeleteModal,
|
||||
showDeleteAllModal,
|
||||
}: {
|
||||
page: pageType;
|
||||
existingChats?: ChatSession[];
|
||||
@@ -33,7 +30,6 @@ export function PagesTab({
|
||||
newFolderId: number | null;
|
||||
showShareModal?: (chatSession: ChatSession) => void;
|
||||
showDeleteModal?: (chatSession: ChatSession) => void;
|
||||
showDeleteAllModal?: () => void;
|
||||
}) {
|
||||
const groupedChatSessions = existingChats
|
||||
? groupSessionsByDateRange(existingChats)
|
||||
@@ -67,98 +63,82 @@ export function PagesTab({
|
||||
const isHistoryEmpty = !existingChats || existingChats.length === 0;
|
||||
|
||||
return (
|
||||
<div className="flex flex-col relative h-full overflow-y-auto mb-1 ml-3 miniscroll mobile:pb-40">
|
||||
<div
|
||||
className={` flex-grow overflow-y-auto ${
|
||||
NEXT_PUBLIC_DELETE_ALL_CHATS_ENABLED && "pb-20 "
|
||||
}`}
|
||||
>
|
||||
{folders && folders.length > 0 && (
|
||||
<div className="py-2 border-b border-border">
|
||||
<div className="text-xs text-subtle flex pb-0.5 mb-1.5 mt-2 font-bold">
|
||||
Chat Folders
|
||||
</div>
|
||||
<FolderList
|
||||
newFolderId={newFolderId}
|
||||
folders={folders}
|
||||
currentChatId={currentChatId}
|
||||
openedFolders={openedFolders}
|
||||
showShareModal={showShareModal}
|
||||
showDeleteModal={showDeleteModal}
|
||||
/>
|
||||
<div className="mb-1 text-text-sidebar ml-3 relative miniscroll mobile:pb-40 overflow-y-auto h-full">
|
||||
{folders && folders.length > 0 && (
|
||||
<div className="py-2 border-b border-border">
|
||||
<div className="text-xs text-subtle flex pb-0.5 mb-1.5 mt-2 font-bold">
|
||||
Chat Folders
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
onDragOver={(event) => {
|
||||
event.preventDefault();
|
||||
setIsDragOver(true);
|
||||
}}
|
||||
onDragLeave={() => setIsDragOver(false)}
|
||||
onDrop={handleDropToRemoveFromFolder}
|
||||
className={`pt-1 transition duration-300 ease-in-out mr-3 ${
|
||||
isDragOver ? "bg-hover" : ""
|
||||
} rounded-md`}
|
||||
>
|
||||
{(page == "chat" || page == "search") && (
|
||||
<p className="my-2 text-xs text-sidebar-subtle flex font-bold">
|
||||
{page == "chat" && "Chat "}
|
||||
{page == "search" && "Search "}
|
||||
History
|
||||
</p>
|
||||
)}
|
||||
{isHistoryEmpty ? (
|
||||
<p className="text-sm mt-2 w-[250px]">
|
||||
Try sending a message! Your chat history will appear here.
|
||||
</p>
|
||||
) : (
|
||||
Object.entries(groupedChatSessions).map(
|
||||
([dateRange, chatSessions], ind) => {
|
||||
if (chatSessions.length > 0) {
|
||||
return (
|
||||
<div key={dateRange}>
|
||||
<div
|
||||
className={`text-xs text-text-sidebar-subtle ${
|
||||
ind != 0 && "mt-5"
|
||||
} flex pb-0.5 mb-1.5 font-medium`}
|
||||
>
|
||||
{dateRange}
|
||||
</div>
|
||||
{chatSessions
|
||||
.filter((chat) => chat.folder_id === null)
|
||||
.map((chat) => {
|
||||
const isSelected = currentChatId === chat.id;
|
||||
return (
|
||||
<div key={`${chat.id}-${chat.name}`}>
|
||||
<ChatSessionDisplay
|
||||
showDeleteModal={showDeleteModal}
|
||||
showShareModal={showShareModal}
|
||||
closeSidebar={closeSidebar}
|
||||
search={page == "search"}
|
||||
chatSession={chat}
|
||||
isSelected={isSelected}
|
||||
skipGradient={isDragOver}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
)
|
||||
)}
|
||||
<FolderList
|
||||
newFolderId={newFolderId}
|
||||
folders={folders}
|
||||
currentChatId={currentChatId}
|
||||
openedFolders={openedFolders}
|
||||
showShareModal={showShareModal}
|
||||
showDeleteModal={showDeleteModal}
|
||||
/>
|
||||
</div>
|
||||
{showDeleteAllModal && NEXT_PUBLIC_DELETE_ALL_CHATS_ENABLED && (
|
||||
<div className="absolute w-full border-t border-t-border bg-background-100 bottom-0 left-0 p-4">
|
||||
<button
|
||||
className="w-full py-2 px-4 text-text-600 hover:text-text-800 bg-background-125 border border-border-strong/50 shadow-sm rounded-md transition-colors duration-200 flex items-center justify-center text-sm"
|
||||
onClick={showDeleteAllModal}
|
||||
>
|
||||
<FiTrash2 className="mr-2" size={14} />
|
||||
Clear All History
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
onDragOver={(event) => {
|
||||
event.preventDefault();
|
||||
setIsDragOver(true);
|
||||
}}
|
||||
onDragLeave={() => setIsDragOver(false)}
|
||||
onDrop={handleDropToRemoveFromFolder}
|
||||
className={`pt-1 transition duration-300 ease-in-out mr-3 ${
|
||||
isDragOver ? "bg-hover" : ""
|
||||
} rounded-md`}
|
||||
>
|
||||
{(page == "chat" || page == "search") && (
|
||||
<p className="my-2 text-xs text-sidebar-subtle flex font-bold">
|
||||
{page == "chat" && "Chat "}
|
||||
{page == "search" && "Search "}
|
||||
History
|
||||
</p>
|
||||
)}
|
||||
{isHistoryEmpty ? (
|
||||
<p className="text-sm mt-2 w-[250px]">
|
||||
{page === "search"
|
||||
? "Try running a search! Your search history will appear here."
|
||||
: "Try sending a message! Your chat history will appear here."}
|
||||
</p>
|
||||
) : (
|
||||
Object.entries(groupedChatSessions).map(
|
||||
([dateRange, chatSessions], ind) => {
|
||||
if (chatSessions.length > 0) {
|
||||
return (
|
||||
<div key={dateRange}>
|
||||
<div
|
||||
className={`text-xs text-text-sidebar-subtle ${
|
||||
ind != 0 && "mt-5"
|
||||
} flex pb-0.5 mb-1.5 font-medium`}
|
||||
>
|
||||
{dateRange}
|
||||
</div>
|
||||
{chatSessions
|
||||
.filter((chat) => chat.folder_id === null)
|
||||
.map((chat) => {
|
||||
const isSelected = currentChatId === chat.id;
|
||||
return (
|
||||
<div key={`${chat.id}-${chat.name}`}>
|
||||
<ChatSessionDisplay
|
||||
showDeleteModal={showDeleteModal}
|
||||
showShareModal={showShareModal}
|
||||
closeSidebar={closeSidebar}
|
||||
search={page == "search"}
|
||||
chatSession={chat}
|
||||
isSelected={isSelected}
|
||||
skipGradient={isDragOver}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,74 +1,50 @@
|
||||
"use client";
|
||||
|
||||
import { HeaderTitle } from "@/components/header/HeaderTitle";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
|
||||
import Link from "next/link";
|
||||
import { useContext } from "react";
|
||||
import { FiSidebar } from "react-icons/fi";
|
||||
import { LogoType } from "@/components/logo/Logo";
|
||||
import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function LogoComponent({
|
||||
enterpriseSettings,
|
||||
backgroundToggled,
|
||||
show,
|
||||
isAdmin,
|
||||
}: {
|
||||
enterpriseSettings: EnterpriseSettings | null;
|
||||
backgroundToggled?: boolean;
|
||||
show?: boolean;
|
||||
isAdmin?: boolean;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
return (
|
||||
<button
|
||||
onClick={isAdmin ? () => router.push("/chat") : () => {}}
|
||||
className={`max-w-[200px] ${
|
||||
!show && "mobile:hidden"
|
||||
} flex items-center gap-x-1`}
|
||||
>
|
||||
{enterpriseSettings && enterpriseSettings.application_name ? (
|
||||
<>
|
||||
<div className="flex-none my-auto">
|
||||
<Logo height={24} width={24} />
|
||||
</div>
|
||||
<div className="w-full">
|
||||
<HeaderTitle backgroundToggled={backgroundToggled}>
|
||||
{enterpriseSettings.application_name}
|
||||
</HeaderTitle>
|
||||
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
|
||||
<p className="text-xs text-left text-subtle">Powered by Onyx</p>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<LogoType />
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
|
||||
export default function FixedLogo({
|
||||
// Whether the sidebar is toggled or not
|
||||
backgroundToggled,
|
||||
}: {
|
||||
backgroundToggled?: boolean;
|
||||
}) {
|
||||
const combinedSettings = useContext(SettingsContext);
|
||||
const settings = combinedSettings?.settings;
|
||||
const enterpriseSettings = combinedSettings?.enterpriseSettings;
|
||||
|
||||
return (
|
||||
<>
|
||||
<Link
|
||||
href="/chat"
|
||||
className="fixed cursor-pointer flex z-40 left-4 top-3 h-8"
|
||||
className="fixed cursor-pointer flex z-40 left-4 top-2 h-8"
|
||||
>
|
||||
<LogoComponent
|
||||
enterpriseSettings={enterpriseSettings!}
|
||||
backgroundToggled={backgroundToggled}
|
||||
/>
|
||||
<div className="max-w-[200px] mobile:hidden flex items-center gap-x-1 my-auto">
|
||||
<div className="flex-none my-auto">
|
||||
<Logo height={24} width={24} />
|
||||
</div>
|
||||
<div className="w-full">
|
||||
{enterpriseSettings && enterpriseSettings.application_name ? (
|
||||
<div>
|
||||
<HeaderTitle backgroundToggled={backgroundToggled}>
|
||||
{enterpriseSettings.application_name}
|
||||
</HeaderTitle>
|
||||
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
|
||||
<p className="text-xs text-subtle">Powered by Onyx</p>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<HeaderTitle backgroundToggled={backgroundToggled}>
|
||||
Onyx
|
||||
</HeaderTitle>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</Link>
|
||||
<div className="mobile:hidden fixed left-4 bottom-4">
|
||||
<FiSidebar
|
||||
|
||||
@@ -14,6 +14,7 @@ import { buildClientUrl } from "@/lib/utilsSS";
|
||||
import { Inter } from "next/font/google";
|
||||
import { EnterpriseSettings, GatingType } from "./admin/settings/interfaces";
|
||||
import { HeaderTitle } from "@/components/header/HeaderTitle";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
|
||||
import { AppProvider } from "@/components/context/AppProvider";
|
||||
import { PHProvider } from "./providers";
|
||||
@@ -22,7 +23,6 @@ import CardSection from "@/components/admin/CardSection";
|
||||
import { Suspense } from "react";
|
||||
import PostHogPageView from "./PostHogPageView";
|
||||
import Script from "next/script";
|
||||
import { LogoType } from "@/components/logo/Logo";
|
||||
|
||||
const inter = Inter({
|
||||
subsets: ["latin"],
|
||||
@@ -115,7 +115,8 @@ export default async function RootLayout({
|
||||
return getPageContent(
|
||||
<div className="flex flex-col items-center justify-center min-h-screen">
|
||||
<div className="mb-2 flex items-center max-w-[175px]">
|
||||
<LogoType />
|
||||
<HeaderTitle>Onyx</HeaderTitle>
|
||||
<Logo height={40} width={40} />
|
||||
</div>
|
||||
|
||||
<CardSection className="max-w-md">
|
||||
@@ -123,8 +124,7 @@ export default async function RootLayout({
|
||||
<p className="text-text-500">
|
||||
Your Onyx instance was not configured properly and your settings
|
||||
could not be loaded. This could be due to an admin configuration
|
||||
issue, an incomplete setup, or backend services that may not be up
|
||||
and running yet.
|
||||
issue or an incomplete setup.
|
||||
</p>
|
||||
<p className="mt-4">
|
||||
If you're an admin, please check{" "}
|
||||
@@ -144,7 +144,7 @@ export default async function RootLayout({
|
||||
community on{" "}
|
||||
<a
|
||||
className="text-link"
|
||||
href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ"
|
||||
href="https://onyx.app?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
@@ -160,7 +160,8 @@ export default async function RootLayout({
|
||||
return getPageContent(
|
||||
<div className="flex flex-col items-center justify-center min-h-screen">
|
||||
<div className="mb-2 flex items-center max-w-[175px]">
|
||||
<LogoType />
|
||||
<HeaderTitle>Onyx</HeaderTitle>
|
||||
<Logo height={40} width={40} />
|
||||
</div>
|
||||
<CardSection className="w-full max-w-md">
|
||||
<h1 className="text-2xl font-bold mb-4 text-error">
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useContext } from "react";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
import { SettingsContext } from "./settings/SettingsProvider";
|
||||
import Image from "next/image";
|
||||
|
||||
export function Logo({
|
||||
@@ -45,10 +45,10 @@ export function Logo({
|
||||
);
|
||||
}
|
||||
|
||||
export function LogoType() {
|
||||
export default function LogoType() {
|
||||
return (
|
||||
<Image
|
||||
className="max-h-8 w-full mr-auto "
|
||||
className="max-h-8 mr-auto "
|
||||
src="/logotype.png"
|
||||
alt="Logo"
|
||||
width={2640}
|
||||
38
web/src/components/LogoType.tsx
Normal file
38
web/src/components/LogoType.tsx
Normal file
@@ -0,0 +1,38 @@
|
||||
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
|
||||
import { HeaderTitle } from "./header/HeaderTitle";
|
||||
import LogoType, { Logo } from "./Logo";
|
||||
import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
|
||||
|
||||
export default function LogoTypeContainer({
|
||||
enterpriseSettings,
|
||||
}: {
|
||||
enterpriseSettings: EnterpriseSettings | null;
|
||||
}) {
|
||||
const onlyLogo =
|
||||
!enterpriseSettings ||
|
||||
!enterpriseSettings.use_custom_logo ||
|
||||
!enterpriseSettings.application_name;
|
||||
|
||||
return (
|
||||
<div className="flex justify-start items-start w-full gap-x-1 my-auto">
|
||||
<div className="flex-none w-fit mr-auto my-auto">
|
||||
{onlyLogo ? <LogoType /> : <Logo height={24} width={24} />}
|
||||
</div>
|
||||
|
||||
{!onlyLogo && (
|
||||
<div className="w-full">
|
||||
{enterpriseSettings && enterpriseSettings.application_name ? (
|
||||
<div>
|
||||
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
|
||||
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
|
||||
<p className="text-xs text-subtle">Powered by Onyx</p>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<HeaderTitle>Onyx</HeaderTitle>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -27,9 +27,7 @@ export function MetadataBadge({
|
||||
size: 12,
|
||||
className: flexNone ? "flex-none" : "mr-0.5 my-auto",
|
||||
})}
|
||||
<p className="max-w-[6rem] text-ellipsis overflow-hidden truncate whitespace-nowrap">
|
||||
{value}lllaasfasdf
|
||||
</p>
|
||||
<div className="my-auto flex">{value}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Logo } from "./logo/Logo";
|
||||
import { Logo } from "./Logo";
|
||||
import { useContext } from "react";
|
||||
import { SettingsContext } from "./settings/SettingsProvider";
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"use client";
|
||||
import React, { useContext } from "react";
|
||||
import Link from "next/link";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
|
||||
import { HeaderTitle } from "@/components/header/HeaderTitle";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
@@ -14,8 +14,6 @@ import {
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { CgArrowsExpandUpLeft } from "react-icons/cg";
|
||||
import LogoWithText from "@/components/header/LogoWithText";
|
||||
import { LogoComponent } from "@/app/chat/shared_chat_search/FixedLogo";
|
||||
|
||||
interface Item {
|
||||
name: string | JSX.Element;
|
||||
@@ -34,22 +32,36 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const settings = combinedSettings.settings;
|
||||
const enterpriseSettings = combinedSettings.enterpriseSettings;
|
||||
|
||||
return (
|
||||
<div className="text-text-settings-sidebar pl-0">
|
||||
<nav className="space-y-2">
|
||||
<div className="w-full ml-4 mt-1 h-8 justify-start mb-4 flex">
|
||||
<LogoComponent
|
||||
show={true}
|
||||
enterpriseSettings={enterpriseSettings!}
|
||||
backgroundToggled={false}
|
||||
isAdmin={true}
|
||||
/>
|
||||
<div className="w-full ml-4 h-8 justify-start mb-4 flex">
|
||||
<div className="flex items-center gap-x-1 my-auto">
|
||||
<div className="flex-none my-auto">
|
||||
<Logo height={24} width={24} />
|
||||
</div>
|
||||
<div className="w-full">
|
||||
{enterpriseSettings && enterpriseSettings.application_name ? (
|
||||
<div>
|
||||
<HeaderTitle>
|
||||
{enterpriseSettings.application_name}
|
||||
</HeaderTitle>
|
||||
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
|
||||
<p className="text-xs text-subtle">Powered by Onyx</p>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<HeaderTitle>Onyx</HeaderTitle>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex w-full justify-center">
|
||||
<Link href="/chat">
|
||||
<button className="text-sm hover:bg-background-settings-hover flex items-center block w-52 py-2.5 flex px-2 text-left hover:bg-opacity-80 cursor-pointer rounded">
|
||||
<button className="text-sm flex items-center block w-52 py-2.5 flex px-2 text-left hover:bg-opacity-80 cursor-pointer rounded">
|
||||
<CgArrowsExpandUpLeft className="my-auto" size={18} />
|
||||
<p className="ml-1 break-words line-clamp-2 ellipsis leading-none">
|
||||
Exit Admin
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Logo } from "../logo/Logo";
|
||||
import { Logo } from "../Logo";
|
||||
|
||||
export default function AuthFlowContainer({
|
||||
children,
|
||||
|
||||
@@ -199,7 +199,7 @@ const AssistantSelector = ({
|
||||
onSelect={(value: string | null) => {
|
||||
if (value == null) return;
|
||||
const { modelName, name, provider } = destructureValue(value);
|
||||
llmOverrideManager.updateLLMOverride({
|
||||
llmOverrideManager.setLlmOverride({
|
||||
name,
|
||||
provider,
|
||||
modelName,
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
"use client";
|
||||
import { User } from "@/lib/types";
|
||||
import { UserDropdown } from "../UserDropdown";
|
||||
import { FiShare2 } from "react-icons/fi";
|
||||
import { SetStateAction, useContext, useEffect } from "react";
|
||||
import { NewChatIcon } from "../icons/icons";
|
||||
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
|
||||
import { ChatSession } from "@/app/chat/interfaces";
|
||||
import Link from "next/link";
|
||||
import { pageType } from "@/app/chat/sessionSidebar/types";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { ChatBanner } from "@/app/chat/ChatBanner";
|
||||
import LogoWithText from "../header/LogoWithText";
|
||||
import { NewChatIcon } from "../icons/icons";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
import LogoType from "../header/LogoType";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { LlmOverrideManager } from "@/lib/hooks";
|
||||
|
||||
export default function FunctionalHeader({
|
||||
page,
|
||||
@@ -19,6 +21,9 @@ export default function FunctionalHeader({
|
||||
toggleSidebar = () => null,
|
||||
reset = () => null,
|
||||
sidebarToggled,
|
||||
liveAssistant,
|
||||
onAssistantChange,
|
||||
llmOverrideManager,
|
||||
documentSidebarToggled,
|
||||
toggleUserSettings,
|
||||
}: {
|
||||
@@ -29,9 +34,11 @@ export default function FunctionalHeader({
|
||||
currentChatSession?: ChatSession | null | undefined;
|
||||
setSharingModalVisible?: (value: SetStateAction<boolean>) => void;
|
||||
toggleSidebar?: () => void;
|
||||
liveAssistant?: Persona;
|
||||
onAssistantChange?: (assistant: Persona) => void;
|
||||
llmOverrideManager?: LlmOverrideManager;
|
||||
toggleUserSettings?: () => void;
|
||||
}) {
|
||||
const settings = useContext(SettingsContext);
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: KeyboardEvent) => {
|
||||
if (event.metaKey || event.ctrlKey) {
|
||||
@@ -69,11 +76,10 @@ export default function FunctionalHeader({
|
||||
return (
|
||||
<div className="left-0 sticky top-0 z-20 w-full relative flex">
|
||||
<div className="items-end flex mt-2 cursor-pointer text-text-700 relative flex w-full">
|
||||
<LogoWithText
|
||||
<LogoType
|
||||
assistantId={currentChatSession?.persona_id}
|
||||
page={page}
|
||||
toggleSidebar={toggleSidebar}
|
||||
toggled={sidebarToggled && !settings?.isMobile}
|
||||
handleNewChat={handleNewChat}
|
||||
/>
|
||||
<div className="mt-2 flex w-full h-8">
|
||||
@@ -97,19 +103,18 @@ export default function FunctionalHeader({
|
||||
</div>
|
||||
|
||||
<div className="invisible">
|
||||
<LogoWithText
|
||||
<LogoType
|
||||
page={page}
|
||||
toggled={sidebarToggled}
|
||||
toggleSidebar={toggleSidebar}
|
||||
handleNewChat={handleNewChat}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="absolute right-0 mobile:top-2 desktop:top-0 flex">
|
||||
<div className="absolute right-0 top-0 flex gap-x-2">
|
||||
{setSharingModalVisible && (
|
||||
<div
|
||||
onClick={() => setSharingModalVisible(true)}
|
||||
className="mobile:hidden mr-2 my-auto rounded cursor-pointer hover:bg-hover-light"
|
||||
className="mobile:hidden my-auto rounded cursor-pointer hover:bg-hover-light"
|
||||
>
|
||||
<FiShare2 size="18" />
|
||||
</div>
|
||||
@@ -121,7 +126,7 @@ export default function FunctionalHeader({
|
||||
/>
|
||||
</div>
|
||||
<Link
|
||||
className="desktop:hidden ml-2 my-auto"
|
||||
className="desktop:hidden my-auto"
|
||||
href={
|
||||
`/${page}` +
|
||||
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useEmbeddingFormContext } from "@/components/context/EmbeddingContext";
|
||||
import { HeaderTitle } from "@/components/header/HeaderTitle";
|
||||
|
||||
import { SettingsIcon } from "@/components/icons/icons";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
import { Logo } from "@/components/Logo";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import Link from "next/link";
|
||||
import { useContext } from "react";
|
||||
|
||||
@@ -10,8 +10,7 @@ export function HeaderTitle({
|
||||
backgroundToggled?: boolean;
|
||||
}) {
|
||||
const isString = typeof children === "string";
|
||||
const textSize =
|
||||
isString && children.length > 10 ? "text-lg mb-[4px] " : "text-2xl";
|
||||
const textSize = isString && children.length > 10 ? "text-xl" : "text-2xl";
|
||||
|
||||
return (
|
||||
<h1
|
||||
@@ -19,7 +18,7 @@ export function HeaderTitle({
|
||||
backgroundToggled
|
||||
? "text-text-sidebar-toggled-header"
|
||||
: "text-text-sidebar-header"
|
||||
} break-words text-left line-clamp-2 ellipsis text-strong overflow-hidden leading-none font-bold`}
|
||||
} break-words line-clamp-2 ellipsis text-strong overflow-visible leading-none font-bold`}
|
||||
>
|
||||
{children}
|
||||
</h1>
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
import { useContext } from "react";
|
||||
import { FiSidebar } from "react-icons/fi";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
|
||||
import {
|
||||
NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED,
|
||||
NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA,
|
||||
} from "@/lib/constants";
|
||||
import { LeftToLineIcon, NewChatIcon, RightToLineIcon } from "../icons/icons";
|
||||
import {
|
||||
Tooltip,
|
||||
@@ -11,11 +14,11 @@ import {
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { pageType } from "@/app/chat/sessionSidebar/types";
|
||||
import { Logo } from "../logo/Logo";
|
||||
import { Logo } from "../Logo";
|
||||
import { HeaderTitle } from "./HeaderTitle";
|
||||
import Link from "next/link";
|
||||
import { LogoComponent } from "@/app/chat/shared_chat_search/FixedLogo";
|
||||
|
||||
export default function LogoWithText({
|
||||
export default function LogoType({
|
||||
toggleSidebar,
|
||||
hideOnMobile,
|
||||
handleNewChat,
|
||||
@@ -36,48 +39,57 @@ export default function LogoWithText({
|
||||
}) {
|
||||
const combinedSettings = useContext(SettingsContext);
|
||||
const enterpriseSettings = combinedSettings?.enterpriseSettings;
|
||||
const useLogoType =
|
||||
!enterpriseSettings?.use_custom_logo &&
|
||||
!enterpriseSettings?.application_name;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${
|
||||
hideOnMobile && "mobile:hidden"
|
||||
} z-[100] ml-2 mt-1 h-8 mb-auto shrink-0 flex gap-x-0 items-center text-xl`}
|
||||
} z-[100] mt-2 h-8 mb-auto shrink-0 flex items-center text-xl`}
|
||||
>
|
||||
{toggleSidebar && page == "chat" ? (
|
||||
<button
|
||||
onClick={() => toggleSidebar()}
|
||||
className="flex gap-x-2 items-center ml-0 desktop:hidden "
|
||||
className="flex gap-x-2 items-center ml-4 desktop:invisible "
|
||||
>
|
||||
{!toggled ? (
|
||||
<Logo className="desktop:hidden -my-2" height={24} width={24} />
|
||||
) : (
|
||||
<LogoComponent
|
||||
show={toggled}
|
||||
enterpriseSettings={enterpriseSettings!}
|
||||
backgroundToggled={toggled}
|
||||
/>
|
||||
)}
|
||||
|
||||
<FiSidebar
|
||||
size={20}
|
||||
className={`text-text-mobile-sidebar ${toggled && "mobile:hidden"}`}
|
||||
className={`${
|
||||
toggled
|
||||
? "text-text-mobile-sidebar-toggled"
|
||||
: "text-text-mobile-sidebar-untoggled"
|
||||
}`}
|
||||
/>
|
||||
{!showArrow && (
|
||||
<Logo className="desktop:hidden -my-2" height={24} width={24} />
|
||||
)}
|
||||
</button>
|
||||
) : (
|
||||
<div className="mr-1 invisible mb-auto h-6 w-6">
|
||||
<Logo height={24} width={24} />
|
||||
lll
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
className={`${
|
||||
showArrow ? "desktop:invisible" : "invisible"
|
||||
} break-words inline-block w-fit text-text-700 text-xl`}
|
||||
} break-words inline-block w-fit ml-2 text-text-700 text-xl`}
|
||||
>
|
||||
<LogoComponent
|
||||
enterpriseSettings={enterpriseSettings!}
|
||||
backgroundToggled={toggled}
|
||||
/>
|
||||
<div className="max-w-[175px]">
|
||||
{enterpriseSettings && enterpriseSettings.application_name ? (
|
||||
<div className="w-full">
|
||||
<HeaderTitle backgroundToggled={toggled}>
|
||||
{enterpriseSettings.application_name}
|
||||
</HeaderTitle>
|
||||
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
|
||||
<p className="text-xs text-subtle">Powered by Onyx</p>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<HeaderTitle backgroundToggled={toggled}>Onyx</HeaderTitle>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{page == "chat" && !showArrow && (
|
||||
@@ -85,7 +97,7 @@ export default function LogoWithText({
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Link
|
||||
className="my-auto mobile:hidden"
|
||||
className="mb-auto mobile:hidden"
|
||||
href={
|
||||
`/${page}` +
|
||||
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA && assistantId
|
||||
@@ -125,14 +137,10 @@ export default function LogoWithText({
|
||||
}}
|
||||
>
|
||||
{!toggled && !combinedSettings?.isMobile ? (
|
||||
<RightToLineIcon className="mobile:hidden text-sidebar-toggle" />
|
||||
<RightToLineIcon className="text-sidebar-toggle" />
|
||||
) : (
|
||||
<LeftToLineIcon className="mobile:hidden text-sidebar-toggle" />
|
||||
<LeftToLineIcon className="text-sidebar-toggle" />
|
||||
)}
|
||||
<FiSidebar
|
||||
size={20}
|
||||
className="hidden mobile:block text-text-mobile-sidebar"
|
||||
/>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
@@ -39,10 +39,7 @@ import Image, { StaticImageData } from "next/image";
|
||||
import jiraSVG from "../../../public/Jira.svg";
|
||||
import confluenceSVG from "../../../public/Confluence.svg";
|
||||
import openAISVG from "../../../public/Openai.svg";
|
||||
import amazonSVG from "../../../public/Amazon.svg";
|
||||
import geminiSVG from "../../../public/Gemini.svg";
|
||||
import metaSVG from "../../../public/Meta.svg";
|
||||
import mistralSVG from "../../../public/Mistral.svg";
|
||||
import openSourceIcon from "../../../public/OpenSource.png";
|
||||
import litellmIcon from "../../../public/LiteLLM.jpg";
|
||||
|
||||
@@ -52,7 +49,6 @@ import asanaIcon from "../../../public/Asana.png";
|
||||
import anthropicSVG from "../../../public/Anthropic.svg";
|
||||
import nomicSVG from "../../../public/nomic.svg";
|
||||
import microsoftIcon from "../../../public/microsoft.png";
|
||||
import microsoftSVG from "../../../public/Microsoft.svg";
|
||||
import mixedBreadSVG from "../../../public/Mixedbread.png";
|
||||
|
||||
import OCIStorageSVG from "../../../public/OCI.svg";
|
||||
@@ -1108,26 +1104,6 @@ export const GeminiIcon = ({
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => <LogoIcon size={size} className={className} src={geminiSVG} />;
|
||||
|
||||
export const AmazonIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => <LogoIcon size={size} className={className} src={amazonSVG} />;
|
||||
|
||||
export const MetaIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => <LogoIcon size={size} className={className} src={metaSVG} />;
|
||||
|
||||
export const MicrosoftIconSVG = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => <LogoIcon size={size} className={className} src={microsoftSVG} />;
|
||||
|
||||
export const MistralIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => <LogoIcon size={size} className={className} src={mistralSVG} />;
|
||||
|
||||
export const VoyageIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
|
||||
@@ -22,7 +22,8 @@ export const DeleteEntityModal = ({
|
||||
<h2 className="my-auto text-2xl font-bold">Delete {entityType}?</h2>
|
||||
</div>
|
||||
<p className="mb-4">
|
||||
Click below to confirm that you want to delete <b>{entityName}</b>
|
||||
Click below to confirm that you want to delete{" "}
|
||||
<b>"{entityName}"</b>
|
||||
</p>
|
||||
{additionalDetails && <p className="mb-4">{additionalDetails}</p>}
|
||||
<div className="flex">
|
||||
|
||||
@@ -75,6 +75,3 @@ export const REGISTRATION_URL =
|
||||
process.env.INTERNAL_URL || "http://127.0.0.1:3001";
|
||||
|
||||
export const TEST_ENV = process.env.TEST_ENV?.toLowerCase() === "true";
|
||||
|
||||
export const NEXT_PUBLIC_DELETE_ALL_CHATS_ENABLED =
|
||||
process.env.NEXT_PUBLIC_DELETE_ALL_CHATS_ENABLED?.toLowerCase() === "true";
|
||||
|
||||
@@ -11,16 +11,12 @@ import { errorHandlingFetcher } from "./fetcher";
|
||||
import { useContext, useEffect, useState } from "react";
|
||||
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
|
||||
import { SourceMetadata } from "./search/interfaces";
|
||||
import { destructureValue, structureValue } from "./llm/utils";
|
||||
import { destructureValue } from "./llm/utils";
|
||||
import { ChatSession } from "@/app/chat/interfaces";
|
||||
import { UsersResponse } from "./users/interfaces";
|
||||
import { Credential } from "./connectors/credentials";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { PersonaCategory } from "@/app/admin/assistants/interfaces";
|
||||
import {
|
||||
LLMProvider,
|
||||
LLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
|
||||
|
||||
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
||||
@@ -161,7 +157,7 @@ export interface LlmOverride {
|
||||
|
||||
export interface LlmOverrideManager {
|
||||
llmOverride: LlmOverride;
|
||||
updateLLMOverride: (newOverride: LlmOverride) => void;
|
||||
setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
|
||||
globalDefault: LlmOverride;
|
||||
setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
|
||||
temperature: number | null;
|
||||
@@ -169,64 +165,52 @@ export interface LlmOverrideManager {
|
||||
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
|
||||
}
|
||||
export function useLlmOverride(
|
||||
llmProviders: LLMProviderDescriptor[],
|
||||
globalModel?: string | null,
|
||||
currentChatSession?: ChatSession,
|
||||
defaultTemperature?: number
|
||||
): LlmOverrideManager {
|
||||
const getValidLlmOverride = (
|
||||
overrideModel: string | null | undefined
|
||||
): LlmOverride => {
|
||||
if (overrideModel) {
|
||||
const model = destructureValue(overrideModel);
|
||||
const provider = llmProviders.find(
|
||||
(p) =>
|
||||
p.model_names.includes(model.modelName) &&
|
||||
p.provider === model.provider
|
||||
);
|
||||
if (provider) {
|
||||
return { ...model, name: provider.name };
|
||||
}
|
||||
}
|
||||
return { name: "", provider: "", modelName: "" };
|
||||
};
|
||||
|
||||
const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
|
||||
getValidLlmOverride(globalModel)
|
||||
globalModel != null
|
||||
? destructureValue(globalModel)
|
||||
: {
|
||||
name: "",
|
||||
provider: "",
|
||||
modelName: "",
|
||||
}
|
||||
);
|
||||
const updateLLMOverride = (newOverride: LlmOverride) => {
|
||||
setLlmOverride(
|
||||
getValidLlmOverride(
|
||||
structureValue(
|
||||
newOverride.name,
|
||||
newOverride.provider,
|
||||
newOverride.modelName
|
||||
)
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
|
||||
currentChatSession && currentChatSession.current_alternate_model
|
||||
? getValidLlmOverride(currentChatSession.current_alternate_model)
|
||||
: { name: "", provider: "", modelName: "" }
|
||||
? destructureValue(currentChatSession.current_alternate_model)
|
||||
: {
|
||||
name: "",
|
||||
provider: "",
|
||||
modelName: "",
|
||||
}
|
||||
);
|
||||
|
||||
const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
|
||||
setLlmOverride(
|
||||
chatSession && chatSession.current_alternate_model
|
||||
? getValidLlmOverride(chatSession.current_alternate_model)
|
||||
? destructureValue(chatSession.current_alternate_model)
|
||||
: globalDefault
|
||||
);
|
||||
};
|
||||
|
||||
const [temperature, setTemperature] = useState<number | null>(
|
||||
defaultTemperature !== undefined ? defaultTemperature : 0
|
||||
defaultTemperature != undefined ? defaultTemperature : 0
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setGlobalDefault(getValidLlmOverride(globalModel));
|
||||
}, [globalModel, llmProviders]);
|
||||
setGlobalDefault(
|
||||
globalModel != null
|
||||
? destructureValue(globalModel)
|
||||
: {
|
||||
name: "",
|
||||
provider: "",
|
||||
modelName: "",
|
||||
}
|
||||
);
|
||||
}, [globalModel]);
|
||||
|
||||
useEffect(() => {
|
||||
setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
|
||||
@@ -249,7 +233,7 @@ export function useLlmOverride(
|
||||
return {
|
||||
updateModelOverrideForChatSession,
|
||||
llmOverride,
|
||||
updateLLMOverride,
|
||||
setLlmOverride,
|
||||
globalDefault,
|
||||
setGlobalDefault,
|
||||
temperature,
|
||||
@@ -299,7 +283,6 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
|
||||
// OpenAI models
|
||||
"o1-mini": "O1 Mini",
|
||||
"o1-preview": "O1 Preview",
|
||||
"o1-2024-12-17": "O1",
|
||||
"gpt-4": "GPT 4",
|
||||
"gpt-4o": "GPT 4o",
|
||||
"gpt-4o-2024-08-06": "GPT 4o (Structured Outputs)",
|
||||
@@ -319,21 +302,6 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
|
||||
"gpt-3.5-turbo-16k-0613": "GPT 3.5 Turbo 16k (June 2023)",
|
||||
"gpt-3.5-turbo-0301": "GPT 3.5 Turbo (March 2023)",
|
||||
|
||||
// Amazon models
|
||||
"amazon.nova-micro@v1": "Amazon Nova Micro",
|
||||
"amazon.nova-lite@v1": "Amazon Nova Lite",
|
||||
"amazon.nova-pro@v1": "Amazon Nova Pro",
|
||||
|
||||
// Meta models
|
||||
"llama-3.2-90b-vision-instruct": "Llama 3.2 90B",
|
||||
"llama-3.2-11b-vision-instruct": "Llama 3.2 11B",
|
||||
"llama-3.3-70b-instruct": "Llama 3.3 70B",
|
||||
|
||||
// Microsoft models
|
||||
"phi-3.5-mini-instruct": "Phi 3.5 Mini",
|
||||
"phi-3.5-moe-instruct": "Phi 3.5 MoE",
|
||||
"phi-3.5-vision-instruct": "Phi 3.5 Vision",
|
||||
|
||||
// Anthropic models
|
||||
"claude-3-opus-20240229": "Claude 3 Opus",
|
||||
"claude-3-sonnet-20240229": "Claude 3 Sonnet",
|
||||
@@ -345,9 +313,6 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
|
||||
"claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet (New)",
|
||||
"claude-3-5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
|
||||
"claude-3.5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
|
||||
"claude-3-5-haiku-20241022": "Claude 3.5 Haiku",
|
||||
"claude-3-5-haiku@20241022": "Claude 3.5 Haiku",
|
||||
"claude-3.5-haiku@20241022": "Claude 3.5 Haiku",
|
||||
|
||||
// Google Models
|
||||
"gemini-1.5-pro": "Gemini 1.5 Pro",
|
||||
@@ -356,11 +321,6 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
|
||||
"gemini-1.5-flash-001": "Gemini 1.5 Flash",
|
||||
"gemini-1.5-pro-002": "Gemini 1.5 Pro (v2)",
|
||||
"gemini-1.5-flash-002": "Gemini 1.5 Flash (v2)",
|
||||
"gemini-2.0-flash-exp": "Gemini 2.0 Flash (Experimental)",
|
||||
|
||||
// Mistral Models
|
||||
"mistral-large-2411": "Mistral Large 24.11",
|
||||
"mistral-large@2411": "Mistral Large 24.11",
|
||||
|
||||
// Bedrock models
|
||||
"meta.llama3-1-70b-instruct-v1:0": "Llama 3.1 70B",
|
||||
|
||||
@@ -74,8 +74,6 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
// custom claude names
|
||||
"claude-3.5-sonnet-v2@20241022",
|
||||
// claude names with AWS Bedrock Suffix
|
||||
"claude-3-opus-20240229-v1:0",
|
||||
"claude-3-sonnet-20240229-v1:0",
|
||||
@@ -95,13 +93,6 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
|
||||
"gemini-1.5-flash-001",
|
||||
"gemini-1.5-pro-002",
|
||||
"gemini-1.5-flash-002",
|
||||
"gemini-2.0-flash-exp",
|
||||
// amazon models
|
||||
"amazon.nova-lite@v1",
|
||||
"amazon.nova-pro@v1",
|
||||
// meta models
|
||||
"llama-3.2-90b-vision-instruct",
|
||||
"llama-3.2-11b-vision-instruct"
|
||||
];
|
||||
|
||||
export function checkLLMSupportsImageInput(model: string) {
|
||||
|
||||
@@ -143,7 +143,7 @@ module.exports = {
|
||||
// Background for chat messages (user bubbles)
|
||||
user: "var(--user-bubble)",
|
||||
|
||||
"userdropdown-background": "var(--background-150)",
|
||||
"userdropdown-background": "var(--background-100)",
|
||||
"text-mobile-sidebar-toggled": "var(--text-800)",
|
||||
"text-mobile-sidebar-untoggled": "var(--text-500)",
|
||||
"text-editing-message": "var(--text-800)",
|
||||
|
||||
Reference in New Issue
Block a user