Compare commits

..

24 Commits

Author SHA1 Message Date
pablodanswer
09e6bd3c9c k 2024-12-18 20:01:44 -08:00
pablodanswer
c1803cdd56 log 2024-12-18 19:20:55 -08:00
pablodanswer
a5b9c76012 validation 2024-12-18 19:13:09 -08:00
rkuo-danswer
e9b10e8b41 temporarily disabling validate indexing fences (#3502)
* temporarily disabling validate indexing fences

* add back a few startup checks in the cloud

* use common vespa client to perform health check

* log vespa url and try using http1 on light worker index methods

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-19 01:32:09 +00:00
pablonyx
a0fa4adb60 Ensure password validation errors propagate (#3509)
* ensure password validation errors propagate

* copy update

* support o1

* improve typing

* Revert "support o1"

This reverts commit 9b7aa6008c.
2024-12-19 00:05:57 +00:00
pablonyx
ca9ba925bd Support o1 (#3510)
* support o1

* nit
2024-12-19 00:05:00 +00:00
rkuo-danswer
833cc5c97c Merge pull request #3497 from emerzon/new_icons
New model icons for LLM Picker
2024-12-18 16:38:31 -08:00
Chris Weaver
23ecf654ed Add support for custom LLM error messages (#3501)
* Add support for custom LLM error messages

* Fix mypy
2024-12-17 22:58:17 -08:00
pablonyx
ddc6a6d2b3 Wrap nits (#3496) 2024-12-17 18:03:38 -08:00
pablonyx
571c8ece32 Slack Workspace Alembic Updates
Old alembic migration + restore workspace
2024-12-17 16:28:59 -08:00
pablodanswer
884bdb4b01 old alembic migration + restore workspace 2024-12-17 16:28:05 -08:00
pablonyx
b3ecf0d59f Migrate user milestone logic (#3493) 2024-12-17 15:59:56 -08:00
Emerson Gomes
f56fda27c9 Add also Microsoft models 2024-12-17 16:37:52 -06:00
Emerson Gomes
b1e4d4ea8d Adds icons for Amazon, Meta and Mistral models (when proxied via LiteLLM) 2024-12-17 16:20:46 -06:00
pablonyx
8db6d49fe5 IAM Auth for RDS (#3479)
* k

* functional iam auth

* k

* k

* improve typing

* add deployment options

* cleanup

* quick clean up

* minor cleanup

* additional clarity for db session operations

* nit

* k

* k

* update configs

* docker compose spacing
2024-12-17 22:02:37 +00:00
pablonyx
28598694b1 Add delete all chats option (#2515)
* Add delete all chats option

* post rebase fixes

* final validation

* minor cleanup

* move up
2024-12-17 02:55:35 +00:00
Emerson Gomes
b5d0df90b9 Remove hardcoded root path for HF models 2024-12-16 19:03:15 -08:00
pablonyx
48be6338ec Update Hubpost tracking form submission (#3261)
* Update Hubpost tracking form submission

* minor cleanup

* validated

* validate

* nit

* k
2024-12-17 02:31:09 +00:00
pablonyx
ed9014f03d Use logotypes where feasible (#3478)
* Use logotypes where feasible

* quick nit

* minor cleanup
2024-12-17 02:13:45 +00:00
rkuo-danswer
2dd51230ed clear indexing fences with no celery tasks queued (#3482)
* allow beat tasks to expire. it isn't important that they all run

* validate fences are in a good state and cancel/fail them if not

* add function timings for important beat tasks

* optimize lookups, add lots of comments

* review changes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-17 00:55:58 +00:00
pablonyx
8b249cbe63 Proper display priority seeding (#3468)
* proper seeding

* k

* clean up
2024-12-17 00:19:45 +00:00
pablonyx
6b50f86cd2 Improved theming (#3204) 2024-12-16 22:24:32 +00:00
pablonyx
bd2805b6df Update llm override defaults (#3230)
* update llm override defaults

* post rebase fix
2024-12-16 22:18:21 +00:00
pablonyx
2847ab003e Prompting (#3372)
* auto generate start prompts

* post rebase clean up

* update for clarity
2024-12-16 21:34:43 +00:00
91 changed files with 2059 additions and 591 deletions

View File

@@ -1,39 +1,49 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine.base import Connection
from typing import Literal
import os
import ssl
import asyncio
from logging.config import fileConfig
import logging
from logging.config import fileConfig
from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from shared_configs.configs import MULTI_TENANT
from onyx.db.engine import build_connection_string
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from onyx.db.engine import get_all_tenant_ids
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Alembic Config object
config = context.config
# Interpret the config file for Python logging.
if config.config_file_name is not None and config.attributes.get(
"configure_logger", True
):
fileConfig(config.config_file_name)
# Add your model's MetaData object here for 'autogenerate' support
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
# Set up logging
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
if USE_IAM_AUTH:
if not os.path.exists(SSL_CERT_FILE):
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
def include_object(
object: SchemaItem,
@@ -49,20 +59,12 @@ def include_object(
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
"""
Determines whether a database object should be included in migrations.
Excludes specified tables from migrations.
"""
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def get_schema_options() -> tuple[str, bool, bool]:
"""
Parses command-line options passed via '-x' in Alembic commands.
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
"""
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -90,16 +92,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
"""
Executes migrations in the specified schema.
"""
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
# Set search_path to the target schema
connection.execute(text(f'SET search_path TO "{schema_name}"'))
context.configure(
@@ -117,11 +115,25 @@ def do_run_migrations(
context.run_migrations()
def provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
if USE_IAM_AUTH:
# Database connection settings
region = AWS_REGION
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
# Get IAM authentication token
token = get_iam_auth_token(host, port, user, region)
# For Alembic / SQLAlchemy in this context, set SSL and password
cparams["password"] = token
cparams["ssl"] = ssl_context
async def run_async_migrations() -> None:
"""
Determines whether to run migrations for a single schema or all schemas,
and executes migrations accordingly.
"""
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
@@ -129,10 +141,16 @@ async def run_async_migrations() -> None:
poolclass=pool.NullPool,
)
if upgrade_all_tenants:
# Run migrations for all tenant schemas sequentially
tenant_schemas = get_all_tenant_ids()
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
def event_provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
for schema in tenant_schemas:
try:
logger.info(f"Migrating schema: {schema}")
@@ -162,15 +180,20 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode.
"""
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
# Run offline migrations for all tenant schemas
engine = create_async_engine(url)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
def event_provide_iam_token_for_alembic_offline(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
@@ -207,9 +230,6 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
"""
Runs migrations in 'online' mode using an asynchronous engine.
"""
asyncio.run(run_async_migrations())

View File

@@ -0,0 +1,87 @@
"""delete workspace
Revision ID: c0aab6edb6dd
Revises: 35e518e0ddf4
Create Date: 2024-12-17 14:37:07.660631
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "c0aab6edb6dd"
down_revision = "35e518e0ddf4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = connector_specific_config - 'workspace'
WHERE source = 'SLACK'
"""
)
def downgrade() -> None:
import json
from sqlalchemy import text
from slack_sdk import WebClient
conn = op.get_bind()
# Fetch all Slack credentials
creds_result = conn.execute(
text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'")
)
all_slack_creds = creds_result.fetchall()
if not all_slack_creds:
return
for cred_row in all_slack_creds:
credential_id, credential_json = cred_row
credential_json = (
credential_json.tobytes().decode("utf-8")
if isinstance(credential_json, memoryview)
else credential_json.decode("utf-8")
)
credential_data = json.loads(credential_json)
slack_bot_token = credential_data.get("slack_bot_token")
if not slack_bot_token:
print(
f"No slack_bot_token found for credential {credential_id}. "
"Your Slack connector will not function until you upgrade and provide a valid token."
)
continue
client = WebClient(token=slack_bot_token)
try:
auth_response = client.auth_test()
workspace = auth_response["url"].split("//")[1].split(".")[0]
# Update only the connectors linked to this credential
# (and which are Slack connectors).
op.execute(
f"""
UPDATE connector AS c
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{{workspace}}',
to_jsonb('{workspace}'::text)
)
FROM connector_credential_pair AS ccp
WHERE ccp.connector_id = c.id
AND c.source = 'SLACK'
AND ccp.credential_id = {credential_id}
"""
)
except Exception:
print(
f"We were unable to get the workspace url for your Slack Connector with id {credential_id}."
)
print("This connector will no longer work until you upgrade.")
continue

View File

@@ -53,3 +53,5 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

View File

@@ -3,12 +3,15 @@ import logging
import uuid
import aiohttp # Async HTTP client
import httpx
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantCreationPayload
@@ -47,13 +50,16 @@ from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
async def get_or_create_tenant_id(
email: str, referral_source: str | None = None
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
try:
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
@@ -281,3 +287,36 @@ def configure_default_api_keys(db_session: Session) -> None:
logger.info(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)
async def submit_to_hubspot(
email: str, referral_source: str | None, request: Request
) -> None:
if not HUBSPOT_TRACKING_URL:
logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
return
# HubSpot tracking cookie
hubspot_cookie = request.cookies.get("hubspotutk")
# IP address
ip_address = request.client.host if request.client else None
data = {
"fields": [
{"name": "email", "value": email},
{"name": "referral_source", "value": referral_source or ""},
],
"context": {
"hutk": hubspot_cookie,
"ipAddress": ip_address,
"pageUri": str(request.url),
"pageName": "User Registration",
},
}
async with httpx.AsyncClient() as client:
response = await client.post(HUBSPOT_TRACKING_URL, json=data)
if response.status_code != 200:
logger.error(f"Failed to submit to HubSpot: {response.text}")

View File

@@ -1,14 +1,38 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
posthog = Posthog(project_api_key=POSTHOG_API_KEY, host=POSTHOG_HOST)
logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
logger.error(f"PostHog error: {error}, items: {items}")
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
on_error=posthog_on_error,
)
def event_telemetry(
distinct_id: str,
event: str,
properties: dict | None = None,
distinct_id: str, event: str, properties: dict | None = None
) -> None:
posthog.capture(distinct_id, event, properties)
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
print("API KEY", POSTHOG_API_KEY)
print("HOST", POSTHOG_HOST)
try:
print(type(distinct_id))
print(type(event))
print(type(properties))
response = posthog.capture(distinct_id, event, properties)
posthog.flush()
print(response)
except Exception as e:
logger.error(f"Error capturing Posthog event: {e}")

View File

@@ -27,8 +27,8 @@ from shared_configs.configs import SENTRY_DSN
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_CACHE_PATH = Path("/root/.cache/huggingface/")
TEMP_HF_CACHE_PATH = Path("/root/.cache/temp_huggingface/")
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
transformer_logging.set_verbosity_error()

View File

@@ -5,6 +5,7 @@ from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
@@ -228,18 +229,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
# We verify the password here to make sure it's valid before we proceed
await self.validate_password(
user_create.password, cast(schemas.UC, user_create)
)
user_count: int | None = None
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
referral_source = (
request.cookies.get("referral_source", None)
if request is not None
else None
)
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_create_tenant_id",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
request=request,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -282,25 +291,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
# Blocking but this should be very quick
with get_session_with_tenant(tenant_id) as db_session:
if not user_count:
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=MilestoneRecordType.USER_SIGNED_UP,
properties=None,
db_session=db_session,
)
else:
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=MilestoneRecordType.MULTIPLE_USERS,
properties=None,
db_session=db_session,
)
return user
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
@@ -346,17 +336,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
referral_source = (
getattr(request.state, "referral_source", None) if request else None
)
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_create_tenant_id",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
request=request,
)
if not tenant_id:
@@ -418,6 +409,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
@@ -471,6 +463,39 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user.email,
request=request,
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
user_count = await get_user_count()
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if user_count == 1:
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=MilestoneRecordType.USER_SIGNED_UP,
properties=None,
db_session=db_session,
)
else:
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=MilestoneRecordType.MULTIPLE_USERS,
properties=None,
db_session=db_session,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.notice(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
@@ -502,7 +527,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_create_tenant_id",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
@@ -563,7 +588,7 @@ class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_create_tenant_id",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=user.email,

View File

@@ -3,7 +3,6 @@ import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
@@ -23,6 +22,7 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
@@ -262,7 +262,8 @@ def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
client = get_vespa_http_client()
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()

View File

@@ -13,7 +13,6 @@ from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -154,10 +153,6 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)

View File

@@ -61,13 +61,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -62,13 +62,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -60,13 +60,15 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -84,14 +84,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once

View File

@@ -1,4 +1,6 @@
# These are helper objects for tracking the keys we need to write in redis
import json
from typing import Any
from typing import cast
from redis import Redis
@@ -23,3 +25,25 @@ def celery_get_queue_length(queue: str, r: Redis) -> int:
total_length += cast(int, length)
return total_length
def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
"""This is a redis specific way to find a task for a particular queue in redis.
It is priority aware and knows how to look through the multiple redis lists
used to implement task prioritization.
This operation is not atomic.
This is a linear search O(n) ... so be careful using it when the task queues can be larger.
Returns true if the id is in the queue, False if not.
"""
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
if task_dict.get("headers", {}).get("id") == task_id:
return True
return False

View File

@@ -4,55 +4,80 @@ from typing import Any
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-prune",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "kombu-message-cleanup",
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {"priority": OnyxCeleryPriority.LOWEST},
"options": {
"priority": OnyxCeleryPriority.LOWEST,
"expires": 60,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
]

View File

@@ -1,7 +1,9 @@
import time
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import Any
import redis
import sentry_sdk
@@ -15,6 +17,7 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
@@ -162,11 +165,19 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
redis_client = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -271,7 +282,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
search_settings_instance,
reindex,
db_session,
r,
redis_client,
tenant_id,
)
if attempt_id:
@@ -286,7 +297,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
db_session, redis_client
)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
@@ -304,6 +317,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
# rkuo: The following code logically appears to work, but the celery inspect code may be unstable
# turning off for the moment to see if it helps cloud stability
# we want to run this less frequently than the overall task
# if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# task_logger.info("Validating indexing fences...")
# validate_indexing_fences(
# tenant_id, self.app, redis_client, redis_client_celery, lock_beat
# )
# except Exception:
# task_logger.exception("Exception while validating indexing fences")
# redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -320,9 +352,190 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"tenant={tenant_id}"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
return tasks_created
def validate_indexing_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_indexing_tasks: set[str] = set()
active_indexing_tasks: set[str] = set()
indexing_worker_names: list[str] = []
# filter for and create an indexing specific inspect object
inspect = celery_app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if not workers:
raise ValueError("No workers found!")
for worker_name in list(workers.keys()):
if "indexing" in worker_name:
indexing_worker_names.append(worker_name)
if len(indexing_worker_names) == 0:
raise ValueError("No indexing workers found!")
inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
# NOTE: each dict entry is a map of worker name to a list of tasks
# we want sets for reserved task and active task id's to optimize
# subsequent validation lookups
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore
if reserved_tasks is None:
raise ValueError("inspect_indexing.reserved() returned None!")
for _, task_list in reserved_tasks.items():
for task in task_list:
reserved_indexing_tasks.add(task["id"])
# get the list of active tasks
active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore
if active_tasks is None:
raise ValueError("inspect_indexing.active() returned None!")
for _, task_list in active_tasks.items():
for task in task_list:
active_indexing_tasks.add(task["id"])
# validate all existing indexing jobs
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
active_indexing_tasks,
r_celery,
db_session,
)
return
def validate_indexing_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
active_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. Active signal is renewed with a 5 minute TTL
1.1 When the fence is created
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved or active list for a worker
2. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
if payload.celery_task_id in active_tasks:
# the celery task is active (aka currently executing)
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# we didn't find any direct indication that associated celery tasks exist, but they still might be there
# due to gaps in our ability to check states during transitions
# Rely on the active signal (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed."
)
redis_connector_index.reset()
return
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
@@ -469,6 +682,7 @@ def try_creating_indexing_task(
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
@@ -502,6 +716,8 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
@@ -642,7 +858,7 @@ def connector_indexing_proxy_task(
if job.process:
exit_code = job.process.exitcode
# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates that they completed successfully
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
@@ -872,6 +1088,7 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
index_attempt_id,
tenant_id,

View File

@@ -1,3 +1,4 @@
import time
import traceback
from datetime import datetime
from datetime import timezone
@@ -89,10 +90,11 @@ logger = setup_logger()
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -161,6 +163,10 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
if lock_beat.owned():
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
return
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
@@ -730,6 +736,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False if it exited early to prevent overlap
"""
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -824,6 +831,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if lock_beat.owned():
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
return True

View File

@@ -1,6 +1,7 @@
import json
import os
import urllib.parse
from typing import cast
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
@@ -144,6 +145,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -174,6 +176,9 @@ try:
except ValueError:
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
@@ -483,6 +488,21 @@ SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
# allow for custom error messages for different errors returned by litellm
# for example, can specify: {"Violated content safety policy": "EVIL REQUEST!!!"}
# to make it so that if an LLM call returns an error containing "Violated content safety policy"
# the end user will see "EVIL REQUEST!!!" instead of the default error message.
_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = os.environ.get(
"LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS", ""
)
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS: dict[str, str] | None = None
try:
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = cast(
dict[str, str], json.loads(_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS)
)
except json.JSONDecodeError:
pass
#####
# Enterprise Edition Configs
#####

View File

@@ -63,6 +63,10 @@ LANGUAGE_CHAT_NAMING_HINT = (
or "The name of the conversation must be in the same language as the user query."
)
# Number of prompts each persona should have
NUM_PERSONA_PROMPTS = 4
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.

View File

@@ -49,6 +49,7 @@ POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
SSL_CERT_FILE = "bundle.pem"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
@@ -274,6 +275,10 @@ class OnyxRedisLocks:
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
class OnyxCeleryPriority(int, Enum):
HIGHEST = 0
HIGH = auto()

View File

@@ -1,5 +1,7 @@
import contextlib
import os
import re
import ssl
import threading
import time
from collections.abc import AsyncGenerator
@@ -10,6 +12,8 @@ from datetime import datetime
from typing import Any
from typing import ContextManager
import asyncpg # type: ignore
import boto3
import jwt
from fastapi import HTTPException
from fastapi import Request
@@ -23,6 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from onyx.configs.app_configs import AWS_REGION
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -37,6 +42,7 @@ from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -49,28 +55,87 @@ logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
# Global so we don't create more than one engine per process
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
"""Create an SSL context if IAM authentication is enabled, else return None."""
if USE_IAM_AUTH:
return ssl.create_default_context(cafile=SSL_CERT_FILE)
return None
ssl_context = create_ssl_context_if_iam()
def get_iam_auth_token(
host: str, port: str, user: str, region: str = "us-east-2"
) -> str:
"""
Generate an IAM authentication token using boto3.
"""
client = boto3.client("rds", region_name=region)
token = client.generate_db_auth_token(
DBHostname=host, Port=int(port), DBUsername=user
)
return token
def configure_psycopg2_iam_auth(
cparams: dict[str, Any], host: str, port: str, user: str, region: str
) -> None:
"""
Configure cparams for psycopg2 with IAM token and SSL.
"""
token = get_iam_auth_token(host, port, user, region)
cparams["password"] = token
cparams["sslmode"] = "require"
cparams["sslrootcert"] = SSL_CERT_FILE
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
user: str = POSTGRES_USER,
password: str = POSTGRES_PASSWORD,
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
use_iam: bool = USE_IAM_AUTH,
region: str = "us-west-2",
) -> str:
if use_iam:
base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
else:
base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
# For asyncpg, do not include application_name in the connection string
if app_name and db_api != "asyncpg":
if "?" in base_conn_str:
return f"{base_conn_str}&application_name={app_name}"
else:
return f"{base_conn_str}?application_name={app_name}"
return base_conn_str
if LOG_POSTGRES_LATENCY:
# Function to log before query execution
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
conn.info["query_start_time"] = time.time()
# Function to log after query execution
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
total_time = time.time() - conn.info["query_start_time"]
# don't spam TOO hard
if total_time > 0.1:
logger.debug(
f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
@@ -78,7 +143,6 @@ if LOG_POSTGRES_LATENCY:
if LOG_POSTGRES_CONN_COUNTS:
# Global counter for connection checkouts and checkins
checkout_count = 0
checkin_count = 0
@@ -105,21 +169,13 @@ if LOG_POSTGRES_CONN_COUNTS:
logger.debug(f"Total connection checkins: {checkin_count}")
"""END DEBUGGING LOGGING"""
def get_db_current_time(db_session: Session) -> datetime:
"""Get the current time from Postgres representing the start of the transaction
Within the same transaction this value will not update
This datetime object returned should be timezone aware, default Postgres timezone is UTC
"""
result = db_session.execute(text("SELECT NOW()")).scalar()
if result is None:
raise ValueError("Database did not return a time")
return result
# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
@@ -128,16 +184,9 @@ def is_valid_schema_name(name: str) -> bool:
class SqlEngine:
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
Will eventually subsume most of the standalone functions in this file.
Sync only for now.
"""
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
# Default parameters for engine creation
DEFAULT_ENGINE_KWARGS = {
"pool_size": 20,
"max_overflow": 5,
@@ -145,33 +194,27 @@ class SqlEngine:
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
def __init__(self) -> None:
pass
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
"""Private helper method to create and return an Engine."""
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
return create_engine(connection_string, **merged_kwargs)
engine = create_engine(connection_string, **merged_kwargs)
if USE_IAM_AUTH:
event.listen(engine, "do_connect", provide_iam_token)
return engine
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
"""Allow the caller to init the engine with extra params. Different clients
such as the API server and different Celery workers and tasks
need different settings.
"""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
@classmethod
def get_engine(cls) -> Engine:
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
already been called. You probably want to init first!
"""
if not cls._engine:
with cls._lock:
if not cls._engine:
@@ -180,12 +223,10 @@ class SqlEngine:
@classmethod
def set_app_name(cls, app_name: str) -> None:
"""Class method to set the app name."""
cls._app_name = app_name
@classmethod
def get_app_name(cls) -> str:
"""Class method to get current app name."""
if not cls._app_name:
return ""
return cls._app_name
@@ -217,56 +258,71 @@ def get_all_tenant_ids() -> list[str] | list[None]:
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
user: str = POSTGRES_USER,
password: str = POSTGRES_PASSWORD,
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def get_sqlalchemy_engine() -> Engine:
return SqlEngine.get_engine()
async def get_async_connection() -> Any:
"""
Custom connection function for async engine when using IAM auth.
"""
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
db = POSTGRES_DB
token = get_iam_auth_token(host, port, user, AWS_REGION)
# asyncpg requires 'ssl="require"' if SSL needed
return await asyncpg.connect(
user=user, password=token, host=host, port=int(port), database=db, ssl="require"
)
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# Underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
app_name = SqlEngine.get_app_name() + "_async"
connection_string = build_connection_string(
db_api=ASYNC_DB_API,
use_iam=USE_IAM_AUTH,
)
connect_args: dict[str, Any] = {}
if app_name:
connect_args["server_settings"] = {"application_name": app_name}
connect_args["ssl"] = ssl_context
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args={
"server_settings": {
"application_name": SqlEngine.get_app_name() + "_async"
}
},
# async engine is only used by API server, so we can use those values
# here as well
connect_args=connect_args,
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
if USE_IAM_AUTH:
@event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
def provide_iam_token_async(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
# For async engine using asyncpg, we still need to set the IAM token here.
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION)
cparams["password"] = token
cparams["ssl"] = ssl_context
return _ASYNC_ENGINE
# Dependency to get the current tenant ID
# If no token is present, uses the default schema for this use case
def get_current_tenant_id(request: Request) -> str:
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -275,7 +331,6 @@ def get_current_tenant_id(request: Request) -> str:
token = request.cookies.get("fastapiusersauth")
if not token:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no token is present, use the default schema or handle accordingly
return current_value
try:
@@ -289,7 +344,6 @@ def get_current_tenant_id(request: Request) -> str:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
except jwt.InvalidTokenError:
return CURRENT_TENANT_ID_CONTEXTVAR.get()
@@ -316,7 +370,6 @@ async def get_async_session_with_tenant(
async with async_session_factory() as session:
try:
# Set the search_path to the tenant's schema
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
@@ -326,8 +379,6 @@ async def get_async_session_with_tenant(
)
except Exception:
logger.exception("Error setting search_path.")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
else:
yield session
@@ -335,9 +386,6 @@ async def get_async_session_with_tenant(
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@@ -349,7 +397,6 @@ def get_session_with_tenant(
) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
@@ -357,27 +404,20 @@ def get_session_with_tenant(
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
@@ -390,21 +430,17 @@ def get_session_with_tenant(
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
@@ -424,12 +460,9 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(
detail="User must authenticate",
)
raise BasicAuthenticationError(detail="User must authenticate")
engine = get_sqlalchemy_engine()
@@ -437,20 +470,17 @@ def get_session() -> Generator[Session, None, None]:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
"""Generate an async database session with the appropriate tenant schema set."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
yield async_session
@@ -461,7 +491,6 @@ def get_session_context_manager() -> ContextManager[Session]:
def get_session_factory() -> sessionmaker[Session]:
"""Get a session factory."""
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
@@ -489,3 +518,13 @@ async def warm_up_connections(
await async_conn.execute(text("SELECT 1"))
for async_conn in async_connections:
await async_conn.close()
def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
if USE_IAM_AUTH:
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
region = os.getenv("AWS_REGION", "us-east-2")
# Configure for psycopg2 with IAM token
configure_psycopg2_iam_auth(cparams, host, port, user, region)

View File

@@ -5,6 +5,8 @@ from typing import Literal
from typing import NotRequired
from typing import Optional
from uuid import uuid4
from pydantic import BaseModel
from typing_extensions import TypedDict # noreorder
from uuid import UUID
@@ -1346,6 +1348,11 @@ class StarterMessage(TypedDict):
message: str
class StarterMessageModel(BaseModel):
name: str
message: str
class Persona(Base):
__tablename__ = "persona"

View File

@@ -543,6 +543,10 @@ def upsert_persona(
if tools is not None:
existing_persona.tools = tools or []
# We should only update display priority if it is not already set
if existing_persona.display_priority is None:
existing_persona.display_priority = display_priority
persona = existing_persona
else:

View File

@@ -369,6 +369,19 @@ class AdminCapable(abc.ABC):
raise NotImplementedError
class RandomCapable(abc.ABC):
"""Class must implement random document retrieval capability"""
@abc.abstractmethod
def random_retrieval(
self,
filters: IndexFilters,
num_to_retrieve: int = 10,
) -> list[InferenceChunkUncleaned]:
"""Retrieve random chunks matching the filters"""
raise NotImplementedError
class BaseIndex(
Verifiable,
Indexable,
@@ -376,6 +389,7 @@ class BaseIndex(
Deletable,
AdminCapable,
IdRetrievalCapable,
RandomCapable,
abc.ABC,
):
"""

View File

@@ -218,4 +218,10 @@ schema DANSWER_CHUNK_NAME {
expression: bm25(content) + (5 * bm25(title))
}
}
rank-profile random_ {
first-phase {
expression: random.match
}
}
}

View File

@@ -2,6 +2,7 @@ import concurrent.futures
import io
import logging
import os
import random
import re
import time
import urllib
@@ -534,7 +535,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with get_vespa_http_client() as http_client:
with get_vespa_http_client(http2=False) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -545,8 +546,12 @@ class VespaIndex(DocumentIndex):
while True:
try:
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}"
)
logger.debug(f'update_single PUT on URL "{vespa_url}"')
resp = http_client.put(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}",
vespa_url,
params=params,
headers={"Content-Type": "application/json"},
json=update_dict,
@@ -618,7 +623,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with get_vespa_http_client() as http_client:
with get_vespa_http_client(http2=False) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -629,8 +634,12 @@ class VespaIndex(DocumentIndex):
while True:
try:
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}"
)
logger.debug(f'delete_single DELETE on URL "{vespa_url}"')
resp = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
vespa_url,
params=params,
)
resp.raise_for_status()
@@ -903,6 +912,32 @@ class VespaIndex(DocumentIndex):
logger.info("Batch deletion completed")
def random_retrieval(
self,
filters: IndexFilters,
num_to_retrieve: int = 10,
) -> list[InferenceChunkUncleaned]:
"""Retrieve random chunks matching the filters using Vespa's random ranking
This method is currently used for random chunk retrieval in the context of
assistant starter message creation (passed as sample context for usage by the assistant).
"""
vespa_where_clauses = build_vespa_filters(filters, remove_trailing_and=True)
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
random_seed = random.randint(0, 1000000)
params: dict[str, str | int | float] = {
"yql": yql,
"hits": num_to_retrieve,
"timeout": VESPA_TIMEOUT,
"ranking.profile": "random_",
"ranking.properties.random.seed": random_seed,
}
return query_vespa(params)
class _VespaDeleteRequest:
def __init__(self, document_id: str, index_name: str) -> None:

View File

@@ -55,7 +55,9 @@ def remove_invalid_unicode_chars(text: str) -> str:
return _illegal_xml_chars_RE.sub("", text)
def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
def get_vespa_http_client(
no_timeout: bool = False, http2: bool = False
) -> httpx.Client:
"""
Configure and return an HTTP client for communicating with Vespa,
including authentication if needed.
@@ -67,5 +69,5 @@ def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
else None,
verify=False if not MANAGED_VESPA else True,
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=True,
http2=http2,
)

View File

@@ -19,7 +19,12 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:
def build_vespa_filters(
filters: IndexFilters,
*,
include_hidden: bool = False,
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
) -> str:
def _build_or_filters(key: str, vals: list[str] | None) -> str:
if vals is None:
return ""
@@ -78,6 +83,9 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
filter_str += _build_time_filter(filters.time_cutoff)
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5] # We remove the trailing " and "
return filter_str

View File

@@ -453,7 +453,9 @@ class DefaultMultiLLM(LLM):
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
if (
DISABLE_LITELLM_STREAMING or self.config.model_name == "o1-2024-12-17"
): # TODO: remove once litellm supports streaming
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return

View File

@@ -29,6 +29,7 @@ OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o1-mini",
"o1-preview",
"o1-2024-12-17",
"gpt-4",
"gpt-4o",
"gpt-4o-mini",

View File

@@ -28,6 +28,7 @@ from litellm.exceptions import RateLimitError # type: ignore
from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -45,10 +46,19 @@ logger = setup_logger()
def litellm_exception_to_error_msg(
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
e: Exception,
llm: LLM,
fallback_to_error_msg: bool = False,
custom_error_msg_mappings: dict[str, str]
| None = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
) -> str:
error_msg = str(e)
if custom_error_msg_mappings:
for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
if error_msg_pattern in error_msg:
return custom_error_msg
if isinstance(e, BadRequestError):
error_msg = "Bad request: The server couldn't process your request. Please check your input."
elif isinstance(e, AuthenticationError):

View File

@@ -0,0 +1,46 @@
PERSONA_CATEGORY_GENERATION_PROMPT = """
Based on the assistant's name, description, and instructions, generate a list of {num_categories}
**unique and diverse** categories that represent different types of starter messages a user
might send to initiate a conversation with this chatbot assistant.
**Ensure that the categories are varied and cover a wide range of topics related to the assistant's capabilities.**
Provide the categories as a JSON array of strings **without any code fences or additional text**.
**Context about the assistant:**
- **Name**: {name}
- **Description**: {description}
- **Instructions**: {instructions}
""".strip()
PERSONA_STARTER_MESSAGE_CREATION_PROMPT = """
Create a starter message that a **user** might send to initiate a conversation with a chatbot assistant.
**Category**: {category}
Your response should include two parts:
1. **Title**: A short, engaging title that reflects the user's intent
(e.g., 'Need Travel Advice', 'Question About Coding', 'Looking for Book Recommendations').
2. **Message**: The actual message that the user would send to the assistant.
This should be natural, engaging, and encourage a helpful response from the assistant.
**Avoid overly specific details; keep the message general and broadly applicable.**
For example:
- Instead of "I've just adopted a 6-month-old Labrador puppy who's pulling on the leash,"
write "I'm having trouble training my new puppy to walk nicely on a leash."
Ensure each part is clearly labeled and separated as shown above.
Do not provide any additional text or explanation and be extremely concise
**Context about the assistant:**
- **Name**: {name}
- **Description**: {description}
- **Instructions**: {instructions}
""".strip()
if __name__ == "__main__":
print(PERSONA_CATEGORY_GENERATION_PROMPT)
print(PERSONA_STARTER_MESSAGE_CREATION_PROMPT)

View File

@@ -31,6 +31,10 @@ class RedisConnectorIndex:
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
# used to signal the overall workflow is still active
# it's difficult to prevent
ACTIVE_PREFIX = PREFIX + "_active"
def __init__(
self,
tenant_id: str | None,
@@ -54,6 +58,7 @@ class RedisConnectorIndex:
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -107,6 +112,26 @@ class RedisConnectorIndex:
# 10 minute TTL is good.
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=300)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True
return False
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
@@ -138,6 +163,7 @@ class RedisConnectorIndex:
return status
def reset(self) -> None:
self.redis.delete(self.active_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)

View File

@@ -0,0 +1,271 @@
import json
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from litellm import get_supported_openai_params
from sqlalchemy.orm import Session
from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
from onyx.configs.chat_configs import NUM_PERSONA_PROMPTS
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
from onyx.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from onyx.db.document_set import get_document_sets_by_ids
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llms
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT
from onyx.prompts.starter_messages import PERSONA_STARTER_MESSAGE_CREATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
logger = setup_logger()
def get_random_chunks_from_doc_sets(
doc_sets: List[str], db_session: Session, user: User | None = None
) -> List[InferenceChunk]:
"""
Retrieves random chunks from the specified document sets.
"""
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(curr_ind_name, sec_ind_name)
acl_filters = build_access_filters_for_user(user, db_session)
filters = IndexFilters(document_set=doc_sets, access_control_list=acl_filters)
chunks = document_index.random_retrieval(
filters=filters, num_to_retrieve=NUM_PERSONA_PROMPT_GENERATION_CHUNKS
)
return cleanup_chunks(chunks)
def parse_categories(content: str) -> List[str]:
"""
Parses the JSON array of categories from the LLM response.
"""
# Clean the response to remove code fences and extra whitespace
content = content.strip().strip("```").strip()
if content.startswith("json"):
content = content[4:].strip()
try:
categories = json.loads(content)
if not isinstance(categories, list):
logger.error("Categories are not a list.")
return []
return categories
except json.JSONDecodeError as e:
logger.error(f"Failed to parse categories: {e}")
return []
def generate_start_message_prompts(
name: str,
description: str,
instructions: str,
categories: List[str],
chunk_contents: str,
supports_structured_output: bool,
fast_llm: Any,
) -> List[FunctionCall]:
"""
Generates the list of FunctionCall objects for starter message generation.
"""
functions = []
for category in categories:
# Create a prompt specific to the category
start_message_generation_prompt = (
PERSONA_STARTER_MESSAGE_CREATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
category=category,
)
)
if chunk_contents:
start_message_generation_prompt += (
"\n\nExample content this assistant has access to:\n"
"'''\n"
f"{chunk_contents}"
"\n'''"
)
if supports_structured_output:
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt, None, None, StarterMessage),
)
)
else:
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt,),
)
)
return functions
def parse_unstructured_output(output: str) -> Dict[str, str]:
"""
Parses the assistant's unstructured output into a dictionary with keys:
- 'name' (Title)
- 'message' (Message)
"""
# Debug output
logger.debug(f"LLM Output for starter message creation: {output}")
# Patterns to match
title_pattern = r"(?i)^\**Title\**\s*:\s*(.+)"
message_pattern = r"(?i)^\**Message\**\s*:\s*(.+)"
# Initialize the response dictionary
response_dict = {}
# Split the output into lines
lines = output.strip().split("\n")
# Variables to keep track of the current key being processed
current_key = None
current_value_lines = []
for line in lines:
# Check for title
title_match = re.match(title_pattern, line.strip())
if title_match:
# Save previous key-value pair if any
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()
current_value_lines = []
current_key = "name"
current_value_lines.append(title_match.group(1).strip())
continue
# Check for message
message_match = re.match(message_pattern, line.strip())
if message_match:
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()
current_value_lines = []
current_key = "message"
current_value_lines.append(message_match.group(1).strip())
continue
# If the line doesn't match a new key, append it to the current value
if current_key:
current_value_lines.append(line.strip())
# Add the last key-value pair
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()
# Validate that the necessary keys are present
if not all(k in response_dict for k in ["name", "message"]):
raise ValueError("Failed to parse the assistant's response.")
return response_dict
def generate_starter_messages(
name: str,
description: str,
instructions: str,
document_set_ids: List[int],
db_session: Session,
user: User | None,
) -> List[StarterMessage]:
"""
Generates starter messages by first obtaining categories and then generating messages for each category.
On failure, returns an empty list (or list with processed starter messages if some messages are processed successfully).
"""
_, fast_llm = get_default_llms(temperature=0.5)
provider = fast_llm.config.model_provider
model = fast_llm.config.model_name
params = get_supported_openai_params(model=model, custom_llm_provider=provider)
supports_structured_output = (
isinstance(params, list) and "response_format" in params
)
# Generate categories
category_generation_prompt = PERSONA_CATEGORY_GENERATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
num_categories=NUM_PERSONA_PROMPTS,
)
category_response = fast_llm.invoke(category_generation_prompt)
categories = parse_categories(cast(str, category_response.content))
if not categories:
logger.error("No categories were generated.")
return []
# Fetch example content if document sets are provided
if document_set_ids:
document_sets = get_document_sets_by_ids(
document_set_ids=document_set_ids,
db_session=db_session,
)
chunks = get_random_chunks_from_doc_sets(
doc_sets=[doc_set.name for doc_set in document_sets],
db_session=db_session,
user=user,
)
# Add example content context
chunk_contents = "\n".join(chunk.content.strip() for chunk in chunks)
else:
chunk_contents = ""
# Generate prompts for starter messages
functions = generate_start_message_prompts(
name,
description,
instructions,
categories,
chunk_contents,
supports_structured_output,
fast_llm,
)
# Run LLM calls in parallel
if not functions:
logger.error("No functions to execute for starter message generation.")
return []
results = run_functions_in_parallel(function_calls=functions)
prompts = []
for response in results.values():
try:
if supports_structured_output:
response_dict = json.loads(response.content)
else:
response_dict = parse_unstructured_output(response.content)
starter_message = StarterMessage(
name=response_dict["name"],
message=response_dict["message"],
)
prompts.append(starter_message)
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse starter message: {e}")
continue
return prompts

View File

@@ -48,6 +48,7 @@ def load_personas_from_yaml(
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
@@ -127,6 +128,7 @@ def load_personas_from_yaml(
display_priority=(
existing_persona.display_priority
if existing_persona is not None
and persona.get("display_priority") is None
else persona.get("display_priority")
),
is_visible=(

View File

@@ -19,6 +19,7 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.db.notification import create_notification
from onyx.db.persona import create_assistant_category
@@ -36,7 +37,11 @@ from onyx.db.persona import update_persona_shared_users
from onyx.db.persona import update_persona_visibility
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaCategoryCreate
from onyx.server.features.persona.models import PersonaCategoryResponse
@@ -377,3 +382,26 @@ def build_final_template_prompt(
retrieval_disabled=retrieval_disabled,
)
)
@basic_router.post("/assistant-prompt-refresh")
def build_assistant_prompts(
generate_persona_prompt_request: GenerateStarterMessageRequest,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
) -> list[StarterMessage]:
try:
logger.info(
"Generating starter messages for user: %s", user.id if user else "Anonymous"
)
return generate_starter_messages(
name=generate_persona_prompt_request.name,
description=generate_persona_prompt_request.description,
instructions=generate_persona_prompt_request.instructions,
document_set_ids=generate_persona_prompt_request.document_set_ids,
db_session=db_session,
user=user,
)
except Exception as e:
logger.exception("Failed to generate starter messages")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -17,6 +17,14 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# More minimal request for generating a persona prompt
class GenerateStarterMessageRequest(BaseModel):
name: str
description: str
instructions: str
document_set_ids: list[int]
class CreatePersonaRequest(BaseModel):
name: str
description: str

View File

@@ -22,6 +22,7 @@ from onyx.utils.variable_functionality import (
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT
_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
_CACHED_UUID: str | None = None
_CACHED_INSTANCE_DOMAIN: str | None = None
@@ -117,9 +118,12 @@ def mt_cloud_telemetry(
event: MilestoneRecordType,
properties: dict | None = None,
) -> None:
print(f"mt_cloud_telemetry {distinct_id} {event} {properties}")
if not MULTI_TENANT:
print("mt_cloud_telemetry not MULTI_TENANT")
return
print("mt_cloud_telemetry MULTI_TENANT")
# MIT version should not need to include any Posthog code
# This is only for Onyx MT Cloud, this code should also never be hit, no reason for any orgs to
# be running the Multi Tenant version of Onyx.
@@ -137,8 +141,11 @@ def create_milestone_and_report(
properties: dict | None,
db_session: Session,
) -> None:
print(f"create_milestone_and_report {user} {event_type} {db_session}")
_, is_new = create_milestone_if_not_exists(user, event_type, db_session)
print(f"create_milestone_and_report {is_new}")
if is_new:
print("create_milestone_and_report is_new")
mt_cloud_telemetry(
distinct_id=distinct_id,
event=event_type,

View File

@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.54.1
litellm==1.55.4
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

View File

@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.54.1
litellm==1.55.4
sentry-sdk[fastapi,celery,starlette]==2.14.0

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-beat
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-heavy
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-indexing
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-light
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-primary
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

View File

@@ -92,6 +92,7 @@ services:
- LOG_POSTGRES_LATENCY=${LOG_POSTGRES_LATENCY:-}
- LOG_POSTGRES_CONN_COUNTS=${LOG_POSTGRES_CONN_COUNTS:-}
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
@@ -103,6 +104,13 @@ services:
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
# Seeding configuration
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -223,6 +231,13 @@ services:
# Enterprise Edition stuff
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:

View File

@@ -84,6 +84,7 @@ services:
# (time spent on finding the right docs + time spent fetching summaries from disk)
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
# Chat Configs
- HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-}
@@ -91,6 +92,13 @@ services:
# Enterprise Edition only
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -192,6 +200,13 @@ services:
# Enterprise Edition only
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:

View File

@@ -22,6 +22,13 @@ services:
- VESPA_HOST=index
- REDIS_HOST=cache
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -52,6 +59,13 @@ services:
- REDIS_HOST=cache
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:

View File

@@ -23,6 +23,13 @@ services:
- VESPA_HOST=index
- REDIS_HOST=cache
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -57,6 +64,13 @@ services:
- REDIS_HOST=cache
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
- USE_IAM_AUTH=${USE_IAM_AUTH}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -223,7 +237,7 @@ services:
volumes:
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
logging:
logging::wq
driver: json-file
options:
max-size: "50m"
@@ -245,3 +259,6 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:

View File

@@ -60,3 +60,12 @@ spec:
envFrom:
- configMapRef:
name: env-configmap
# Uncomment if you are using IAM auth for Postgres
# volumeMounts:
# - name: bundle-pem
# mountPath: "/app/certs"
# readOnly: true
# volumes:
# - name: bundle-pem
# secret:
# secretName: bundle-pem-secret

View File

@@ -43,6 +43,7 @@ spec:
# - name: my-ca-cert-volume
# mountPath: /etc/ssl/certs/custom-ca.crt
# subPath: my-ca.crt
# Optional volume for CA certificate
# volumes:
# - name: my-cas-cert-volume
@@ -51,3 +52,13 @@ spec:
# items:
# - key: my-ca.crt
# path: my-ca.crt
# Uncomment if you are using IAM auth for Postgres
# volumeMounts:
# - name: bundle-pem
# mountPath: "/app/certs"
# readOnly: true
# volumes:
# - name: bundle-pem
# secret:
# secretName: bundle-pem-secret

1
web/public/Amazon.svg Executable file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 7.0 KiB

9
web/public/Meta.svg Executable file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 340 KiB

6
web/public/Microsoft.svg Executable file
View File

@@ -0,0 +1,6 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect x="1.33325" y="1.3335" width="6.33333" height="6.33333" fill="#F25022"/>
<rect x="8.33325" y="1.3335" width="6.33333" height="6.33333" fill="#80BA01"/>
<rect x="8.33325" y="8.3335" width="6.33333" height="6.33333" fill="#FFB902"/>
<rect x="1.33325" y="8.3335" width="6.33333" height="6.33333" fill="#02A4EF"/>
</svg>

After

Width:  |  Height:  |  Size: 425 B

1
web/public/Mistral.svg Executable file
View File

@@ -0,0 +1 @@
<svg viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg" fill-rule="evenodd" clip-rule="evenodd" stroke-linejoin="round" stroke-miterlimit="2"><path d="M189.08 303.228H94.587l.044-94.446h94.497l-.048 94.446z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M283.528 397.674h-94.493l.044-94.446h94.496l-.047 94.446z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M283.575 303.228H189.08l.046-94.446h94.496l-.047 94.446z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M378.07 303.228h-94.495l.044-94.446h94.498l-.047 94.446zM189.128 208.779H94.633l.044-94.448h94.498l-.047 94.448zM378.115 208.779h-94.494l.045-94.448h94.496l-.047 94.448zM94.587 303.227H.093l.044-96.017h94.496l-.046 96.017z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M94.633 208.779H.138l.046-94.448H94.68l-.047 94.448z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M94.68 115.902H.185L.23 19.885h94.498l-.047 96.017zM472.657 114.331h-94.495l.044-94.446h94.497l-.046 94.446zM94.54 399.244H.046l.044-97.588h94.497l-.047 97.588z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M94.495 492.123H0l.044-94.446H94.54l-.045 94.446zM472.563 303.228H378.07l.044-94.446h94.496l-.047 94.446zM472.61 208.779h-94.495l.044-94.448h94.498l-.047 94.448z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M472.517 397.674h-94.494l.044-94.446h94.497l-.047 94.446z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M472.47 492.121h-94.493l.044-96.017h94.496l-.047 96.017z" fill="#1c1c1b" fill-rule="nonzero"/><path d="M228.375 303.22h-96.061l.046-94.446h96.067l-.052 94.446z" fill="#ff7000" fill-rule="nonzero"/><path d="M322.827 397.666h-94.495l.044-96.018h94.498l-.047 96.018z" fill="#ff4900" fill-rule="nonzero"/><path d="M324.444 303.22h-97.636l.046-94.446h97.638l-.048 94.446z" fill="#ff7000" fill-rule="nonzero"/><path d="M418.938 303.22h-96.064l.045-94.446h96.066l-.047 94.446z" fill="#ff7000" fill-rule="nonzero"/><path d="M228.423 208.77H132.36l.045-94.445h96.066l-.05 94.446zM418.985 208.77H322.92l.044-94.445h96.069l-.048 94.446z" fill="#ffa300" fill-rule="nonzero"/><path d="M133.883 304.79H39.392l.044-96.017h94.496l-.049 96.017z" fill="#ff7000" fill-rule="nonzero"/><path d="M133.929 208.77H39.437l.044-95.445h94.496l-.048 95.445z" fill="#ffa300" fill-rule="nonzero"/><path d="M133.976 114.325H39.484l.044-94.448h94.497l-.05 94.448zM511.954 115.325h-94.493l.044-95.448h94.497l-.048 95.448z" fill="#ffce00" fill-rule="nonzero"/><path d="M133.836 399.667H39.345l.044-96.447h94.496l-.049 96.447z" fill="#ff4900" fill-rule="nonzero"/><path d="M133.79 492.117H39.3l.044-94.448h94.496l-.049 94.448z" fill="#ff0107" fill-rule="nonzero"/><path d="M511.862 303.22h-94.495l.046-94.446h94.496l-.047 94.446z" fill="#ff7000" fill-rule="nonzero"/><path d="M511.907 208.77h-94.493l.044-94.445h94.496l-.047 94.446z" fill="#ffa300" fill-rule="nonzero"/><path d="M511.815 398.666h-94.493l.044-95.447h94.496l-.047 95.447z" fill="#ff4900" fill-rule="nonzero"/><path d="M511.77 492.117h-94.496l.046-94.448h94.496l-.047 94.448z" fill="#ff0107" fill-rule="nonzero"/></svg>

After

Width:  |  Height:  |  Size: 2.9 KiB

View File

@@ -75,7 +75,8 @@ export default function Page() {
},
{} as Record<SourceCategory, SourceMetadata[]>
);
}, [sources, searchTerm]);
}, [sources, filterSources, searchTerm]);
const handleKeyPress = (e: React.KeyboardEvent<HTMLInputElement>) => {
if (e.key === "Enter") {
const filteredCategories = Object.entries(categorizedSources).filter(

View File

@@ -1,7 +1,7 @@
"use client";
import { Option } from "@/components/Dropdown";
import { generateRandomIconShape, createSVG } from "@/lib/assistantIconUtils";
import { CCPairBasicInfo, DocumentSet, User } from "@/lib/types";
import { Separator } from "@/components/ui/separator";
import { Button } from "@/components/ui/button";
@@ -9,12 +9,11 @@ import { Textarea } from "@/components/ui/textarea";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import {
ArrayHelpers,
ErrorMessage,
Field,
FieldArray,
Form,
Formik,
FormikProps,
useFormikContext,
} from "formik";
import {
@@ -27,7 +26,6 @@ import {
import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { Option } from "@/components/Dropdown";
import { addAssistantToList } from "@/lib/assistants/updateAssistantPreferences";
import { checkLLMSupportsImageInput, destructureValue } from "@/lib/llm/utils";
import { ToolSnapshot } from "@/lib/tools/interfaces";
@@ -41,10 +39,9 @@ import {
} from "@/components/ui/tooltip";
import Link from "next/link";
import { useRouter } from "next/navigation";
import { useEffect, useState } from "react";
import { FiInfo, FiX } from "react-icons/fi";
import { useEffect, useMemo, useState } from "react";
import { FiInfo, FiRefreshCcw } from "react-icons/fi";
import * as Yup from "yup";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, PersonaCategory, StarterMessage } from "./interfaces";
@@ -66,6 +63,9 @@ import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { buildImgUrl } from "@/app/chat/files/images/utils";
import { LlmList } from "@/components/llm/LLMList";
import { useAssistants } from "@/components/context/AssistantsContext";
import { debounce } from "lodash";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import StarterMessagesList from "./StarterMessageList";
import { Input } from "@/components/ui/input";
import { CategoryCard } from "./CategoryCard";
@@ -129,12 +129,14 @@ export function AssistantEditor({
];
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
const [hasEditedStarterMessage, setHasEditedStarterMessage] = useState(false);
const [showPersonaCategory, setShowPersonaCategory] = useState(!admin);
// state to persist across formik reformatting
const [defautIconColor, _setDeafultIconColor] = useState(
colorOptions[Math.floor(Math.random() * colorOptions.length)]
);
const [isRefreshing, setIsRefreshing] = useState(false);
const [defaultIconShape, setDefaultIconShape] = useState<any>(null);
@@ -148,6 +150,10 @@ export function AssistantEditor({
const [removePersonaImage, setRemovePersonaImage] = useState(false);
const autoStarterMessageEnabled = useMemo(
() => llmProviders.length > 0,
[llmProviders.length]
);
const isUpdate = existingPersona !== undefined && existingPersona !== null;
const existingPrompt = existingPersona?.prompts[0] ?? null;
const defaultProvider = llmProviders.find(
@@ -217,7 +223,24 @@ export function AssistantEditor({
existingPersona?.llm_model_provider_override ?? null,
llm_model_version_override:
existingPersona?.llm_model_version_override ?? null,
starter_messages: existingPersona?.starter_messages ?? [],
starter_messages: existingPersona?.starter_messages ?? [
{
name: "",
message: "",
},
{
name: "",
message: "",
},
{
name: "",
message: "",
},
{
name: "",
message: "",
},
],
enabled_tools_map: enabledToolsMap,
icon_color: existingPersona?.icon_color ?? defautIconColor,
icon_shape: existingPersona?.icon_shape ?? defaultIconShape,
@@ -228,6 +251,44 @@ export function AssistantEditor({
groups: existingPersona?.groups ?? [],
};
interface AssistantPrompt {
message: string;
name: string;
}
const debouncedRefreshPrompts = debounce(
async (values: any, setFieldValue: any) => {
if (!autoStarterMessageEnabled) {
return;
}
setIsRefreshing(true);
try {
const response = await fetch("/api/persona/assistant-prompt-refresh", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
name: values.name,
description: values.description,
document_set_ids: values.document_set_ids,
instructions: values.system_prompt || values.task_prompt,
}),
});
const data: AssistantPrompt = await response.json();
if (response.ok) {
setFieldValue("starter_messages", data);
}
} catch (error) {
console.error("Failed to refresh prompts:", error);
} finally {
setIsRefreshing(false);
}
},
1000
);
const [isRequestSuccessful, setIsRequestSuccessful] = useState(false);
return (
@@ -421,6 +482,8 @@ export function AssistantEditor({
isSubmitting,
values,
setFieldValue,
errors,
...formikProps
}: FormikProps<any>) => {
function toggleToolInValues(toolId: number) {
@@ -445,6 +508,7 @@ export function AssistantEditor({
return (
<Form className="w-full text-text-950">
{/* Refresh starter messages when name or description changes */}
<div className="w-full flex gap-x-2 justify-center">
<Popover
open={isIconDropdownOpen}
@@ -984,6 +1048,91 @@ export function AssistantEditor({
</div>
</div>
<div className="mb-6 w-full flex flex-col">
<div className="flex gap-x-2 items-center">
<div className="block font-medium text-base">
Starter Messages
</div>
</div>
<SubLabel>
Pre-configured messages that help users understand what this
assistant can do and how to interact with it effectively.
</SubLabel>
<div className="relative w-fit">
<TooltipProvider delayDuration={50}>
<Tooltip>
<TooltipTrigger asChild>
<div>
<Button
type="button"
size="sm"
onClick={() =>
debouncedRefreshPrompts(values, setFieldValue)
}
disabled={
!autoStarterMessageEnabled ||
isRefreshing ||
(Object.keys(errors).length > 0 &&
Object.keys(errors).some(
(key) => !key.startsWith("starter_messages")
))
}
className={`
px-3 py-2
mr-auto
my-2
flex gap-x-2
text-sm font-medium
rounded-lg shadow-sm
items-center gap-2
transition-colors duration-200
${
isRefreshing || !autoStarterMessageEnabled
? "bg-gray-100 text-gray-400 cursor-not-allowed"
: "bg-blue-50 text-blue-600 hover:bg-blue-100 active:bg-blue-200"
}
`}
>
<div className="flex items-center gap-x-2">
{isRefreshing ? (
<FiRefreshCcw className="w-4 h-4 animate-spin text-gray-400" />
) : (
<SwapIcon className="w-4 h-4 text-blue-600" />
)}
Generate
</div>
</Button>
</div>
</TooltipTrigger>
{!autoStarterMessageEnabled && (
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
No LLM providers configured. Generation is not
available.
</p>
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
</div>
<div className="w-full">
<FieldArray
name="starter_messages"
render={(arrayHelpers: ArrayHelpers) => (
<StarterMessagesList
isRefreshing={isRefreshing}
values={values.starter_messages}
arrayHelpers={arrayHelpers}
touchStarterMessages={() => {
setHasEditedStarterMessage(true);
}}
/>
)}
/>
</div>
</div>
{admin && (
<AdvancedOptionsToggle
title="Categories"
@@ -1190,136 +1339,12 @@ export function AssistantEditor({
</>
)}
<div className="mb-6 flex flex-col">
<div className="flex gap-x-2 items-center">
<div className="block font-medium text-base">
Starter Messages (Optional){" "}
</div>
</div>
<SubLabel>
Add pre-defined messages to help users get started. Only
the first 4 will be displayed.
</SubLabel>
<FieldArray
name="starter_messages"
render={(
arrayHelpers: ArrayHelpers<StarterMessage[]>
) => (
<div>
{values.starter_messages &&
values.starter_messages.length > 0 &&
values.starter_messages.map(
(
starterMessage: StarterMessage,
index: number
) => {
return (
<div
key={index}
className={index === 0 ? "mt-2" : "mt-6"}
>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label small>Name</Label>
<SubLabel>
Shows up as the &quot;title&quot;
for this Starter Message. For
example, &quot;Write an email&quot;.
</SubLabel>
<Field
name={`starter_messages[${index}].name`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages[${index}].name`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label small>Message</Label>
<SubLabel>
The actual message to be sent as the
initial user message if a user
selects this starter prompt. For
example, &quot;Write me an email to
a client about a new billing feature
we just released.&quot;
</SubLabel>
<Field
name={`starter_messages[${index}].message`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
min-h-12
mr-4
line-clamp-
`}
as="textarea"
autoComplete="off"
/>
<ErrorMessage
name={`starter_messages[${index}].message`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() =>
arrayHelpers.remove(index)
}
/>
</div>
</div>
</div>
);
}
)}
<Button
onClick={() => {
arrayHelpers.push({
name: "",
description: "",
message: "",
});
}}
className="mt-3"
size="sm"
variant="next"
>
Add New
</Button>
</div>
)}
/>
</div>
<IsPublicGroupSelector
formikProps={{
values,
isSubmitting,
setFieldValue,
errors,
...formikProps,
}}
objectName="assistant"

View File

@@ -0,0 +1,198 @@
"use client";
import { ArrayHelpers, ErrorMessage, Field, useFormikContext } from "formik";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@radix-ui/react-tooltip";
import { useEffect } from "react";
import { FiInfo, FiTrash2, FiPlus } from "react-icons/fi";
import { StarterMessage } from "./interfaces";
import { Label } from "@/components/admin/connectors/Field";
export default function StarterMessagesList({
values,
arrayHelpers,
isRefreshing,
touchStarterMessages,
}: {
values: StarterMessage[];
arrayHelpers: ArrayHelpers;
isRefreshing: boolean;
touchStarterMessages: () => void;
}) {
const { handleChange } = useFormikContext();
// Group starter messages into rows of 2 for display purposes
const rows = values.reduce((acc: StarterMessage[][], curr, i) => {
if (i % 2 === 0) acc.push([curr]);
else acc[acc.length - 1].push(curr);
return acc;
}, []);
const canAddMore = values.length <= 6;
return (
<div className="mt-4 flex flex-col gap-6">
{rows.map((row, rowIndex) => (
<div key={rowIndex} className="flex items-start gap-4">
<div className="grid grid-cols-2 gap-6 w-full xl:w-fit">
{row.map((starterMessage, colIndex) => (
<div
key={rowIndex * 2 + colIndex}
className="bg-white max-w-full w-full xl:w-[500px] border border-border rounded-lg shadow-md transition-shadow duration-200 p-6"
>
<div className="space-y-5">
{isRefreshing ? (
<div className="w-full">
<div className="w-full">
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
<div className="h-10 w-full bg-gray-200 rounded animate-pulse" />
</div>
<div>
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
<div className="h-10 w-full bg-gray-200 rounded animate-pulse" />
</div>
<div>
<div className="h-4 w-24 bg-gray-200 rounded animate-pulse mb-2" />
<div className="h-24 w-full bg-gray-200 rounded animate-pulse" />
</div>
</div>
) : (
<>
<div>
<div className="flex w-full items-center gap-x-1">
<Label
small
className="text-sm font-medium text-gray-700"
>
Name
</Label>
<TooltipProvider delayDuration={50}>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
Shows up as the &quot;title&quot; for this
Starter Message. For example, &quot;Write an
email.&quot;
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<Field
name={`starter_messages.${
rowIndex * 2 + colIndex
}.name`}
className="mt-1 w-full px-4 py-2.5 bg-background border border-border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition"
autoComplete="off"
placeholder="Enter a name..."
onChange={(e: any) => {
touchStarterMessages();
handleChange(e);
}}
/>
<ErrorMessage
name={`starter_messages.${
rowIndex * 2 + colIndex
}.name`}
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
<div>
<div className="flex w-full items-center gap-x-1">
<Label
small
className="text-sm font-medium text-gray-700"
>
Message
</Label>
<TooltipProvider delayDuration={50}>
<Tooltip>
<TooltipTrigger>
<FiInfo size={12} />
</TooltipTrigger>
<TooltipContent side="top" align="center">
<p className="bg-background-900 max-w-[200px] mb-1 text-sm rounded-lg p-1.5 text-white">
The actual message to be sent as the initial
user message.
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<Field
name={`starter_messages.${
rowIndex * 2 + colIndex
}.message`}
className="mt-1 text-sm w-full px-4 py-2.5 bg-background border border-border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-transparent transition min-h-[100px] resize-y"
as="textarea"
autoComplete="off"
placeholder="Enter the message..."
onChange={(e: any) => {
touchStarterMessages();
handleChange(e);
}}
/>
<ErrorMessage
name={`starter_messages.${
rowIndex * 2 + colIndex
}.message`}
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
</>
)}
</div>
</div>
))}
</div>
<button
type="button"
onClick={() => {
arrayHelpers.remove(rowIndex * 2 + 1);
arrayHelpers.remove(rowIndex * 2);
}}
className="p-1.5 bg-white border border-gray-200 rounded-full text-gray-400 hover:text-red-500 hover:border-red-200 transition-colors mt-2"
aria-label="Delete row"
>
<FiTrash2 size={14} />
</button>
</div>
))}
{canAddMore && (
<button
type="button"
onClick={() => {
arrayHelpers.push({
name: "",
message: "",
});
arrayHelpers.push({
name: "",
message: "",
});
}}
className="self-start flex items-center gap-2 px-4 py-2 bg-white border border-gray-200 rounded-lg text-gray-600 hover:bg-gray-50 hover:border-gray-300 transition-colors"
>
<FiPlus size={16} />
<span>Add Row</span>
</button>
)}
</div>
);
}

View File

@@ -1,8 +1,12 @@
import {
AnthropicIcon,
AmazonIcon,
AWSIcon,
AzureIcon,
CPUIcon,
MicrosoftIconSVG,
MistralIcon,
MetaIcon,
OpenAIIcon,
GeminiIcon,
OpenSourceIcon,
@@ -72,12 +76,25 @@ export const getProviderIcon = (providerName: string, modelName?: string) => {
switch (providerName) {
case "openai":
// Special cases for openai based on modelName
if (modelName?.toLowerCase().includes("amazon")) {
return AmazonIcon;
}
if (modelName?.toLowerCase().includes("phi")) {
return MicrosoftIconSVG;
}
if (modelName?.toLowerCase().includes("mistral")) {
return MistralIcon;
}
if (modelName?.toLowerCase().includes("llama")) {
return MetaIcon;
}
if (modelName?.toLowerCase().includes("gemini")) {
return GeminiIcon;
}
if (modelName?.toLowerCase().includes("claude")) {
return AnthropicIcon;
}
return OpenAIIcon; // Default for openai
case "anthropic":
return AnthropicIcon;

View File

@@ -2,7 +2,7 @@ import { useFormContext } from "@/components/context/FormContext";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { SettingsIcon } from "@/components/icons/icons";
import { Logo } from "@/components/Logo";
import { Logo } from "@/components/logo/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { credentialTemplates } from "@/lib/connectors/credentials";
import Link from "next/link";
@@ -62,11 +62,11 @@ export default function Sidebar() {
];
return (
<div className="flex flex-none w-[250px] bg-background text-default">
<div className="flex flex-none w-[250px] text-default">
<div
className={`
fixed
bg-background-100
bg-background-sidebar
h-screen
transition-all
bg-opacity-80

View File

@@ -6,7 +6,7 @@ import { useCallback, useEffect, useState } from "react";
import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "../waiting-on-verification/RequestNewVerificationEmail";
import { User } from "@/lib/types";
import { Logo } from "@/components/Logo";
import { Logo } from "@/components/logo/Logo";
export function Verify({ user }: { user: User | null }) {
const searchParams = useSearchParams();

View File

@@ -8,7 +8,7 @@ import { HealthCheckBanner } from "@/components/health/healthcheck";
import { User } from "@/lib/types";
import Text from "@/components/ui/text";
import { RequestNewVerificationEmail } from "./RequestNewVerificationEmail";
import { Logo } from "@/components/Logo";
import { Logo } from "@/components/logo/Logo";
export default async function Page() {
// catch cases where the backend is completely unreachable here

View File

@@ -273,6 +273,7 @@ export function ChatPage({
};
const llmOverrideManager = useLlmOverride(
llmProviders,
modelVersionFromSearchParams || (user?.preferences.default_model ?? null),
selectedChatSession,
defaultTemperature
@@ -319,9 +320,9 @@ export function ChatPage({
);
if (personaDefault) {
llmOverrideManager.setLlmOverride(personaDefault);
llmOverrideManager.updateLLMOverride(personaDefault);
} else if (user?.preferences.default_model) {
llmOverrideManager.setLlmOverride(
llmOverrideManager.updateLLMOverride(
destructureValue(user?.preferences.default_model)
);
}
@@ -1203,7 +1204,6 @@ export function ChatPage({
assistant_message_id: number;
frozenMessageMap: Map<number, Message>;
} = null;
try {
const mapKeys = Array.from(
currentMessageMap(completeMessageDetail).keys()
@@ -2146,7 +2146,7 @@ export function ChatPage({
page="chat"
ref={innerSidebarElementRef}
toggleSidebar={toggleSidebar}
toggled={toggledSidebar && !settings?.isMobile}
toggled={toggledSidebar}
backgroundToggled={toggledSidebar || showHistorySidebar}
existingChats={chatSessions}
currentChatSession={selectedChatSession}
@@ -2168,7 +2168,6 @@ export function ChatPage({
fixed
right-0
z-[1000]
bg-background
h-screen
transition-all
@@ -2218,8 +2217,6 @@ export function ChatPage({
{liveAssistant && (
<FunctionalHeader
toggleUserSettings={() => setUserSettingsToggled(true)}
liveAssistant={liveAssistant}
onAssistantChange={onAssistantChange}
sidebarToggled={toggledSidebar}
reset={() => setMessage("")}
page="chat"
@@ -2231,7 +2228,6 @@ export function ChatPage({
toggleSidebar={toggleSidebar}
currentChatSession={selectedChatSession}
documentSidebarToggled={documentSidebarToggled}
llmOverrideManager={llmOverrideManager}
/>
)}

View File

@@ -124,9 +124,9 @@ export default function RegenerateOption({
onHoverChange: (isHovered: boolean) => void;
onDropdownVisibleChange: (isVisible: boolean) => void;
}) {
const llmOverrideManager = useLlmOverride();
const { llmProviders } = useChatContext();
const llmOverrideManager = useLlmOverride(llmProviders);
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
const llmOptionsByProvider: {

View File

@@ -81,6 +81,8 @@ export function ChatDocumentDisplay({
}
};
const hasMetadata =
document.updated_at || Object.keys(document.metadata).length > 0;
return (
<div className={`opacity-100 ${modal ? "w-[90vw]" : "w-full"}`}>
<div
@@ -107,8 +109,14 @@ export function ChatDocumentDisplay({
: document.semantic_identifier || document.document_id}
</div>
</div>
<DocumentMetadataBlock modal={modal} document={document} />
<div className="line-clamp-3 pt-2 text-sm font-normal leading-snug text-gray-600">
{hasMetadata && (
<DocumentMetadataBlock modal={modal} document={document} />
)}
<div
className={`line-clamp-3 text-sm font-normal leading-snug text-gray-600 ${
hasMetadata ? "mt-2" : ""
}`}
>
{buildDocumentSummaryDisplay(
document.match_highlights,
document.blurb

View File

@@ -417,9 +417,7 @@ export function ChatInputBar({
style={{ scrollbarWidth: "thin" }}
role="textarea"
aria-multiline
placeholder={`Send a message ${
!settings?.isMobile ? "or try using @ or /" : ""
}`}
placeholder="Ask me anything.."
value={message}
onKeyDown={(event) => {
if (

View File

@@ -812,6 +812,7 @@ export const HumanMessage = ({
outline-none
placeholder-gray-400
resize-none
text-text-editing-message
pl-4
overflow-y-auto
pr-12
@@ -870,7 +871,6 @@ export const HumanMessage = ({
py-2
px-3
w-fit
bg-hover
bg-background-strong
text-sm
rounded-lg
@@ -896,15 +896,13 @@ export const HumanMessage = ({
<TooltipProvider delayDuration={1000}>
<Tooltip>
<TooltipTrigger>
<button
className="hover:bg-hover p-1.5 rounded"
<HoverableIcon
icon={<FiEdit2 className="text-gray-600" />}
onClick={() => {
setIsEditing(true);
setIsHovered(false);
}}
>
<FiEdit2 className="!h-4 !w-4" />
</button>
/>
</TooltipTrigger>
<TooltipContent>Edit</TooltipContent>
</Tooltip>

View File

@@ -35,7 +35,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
checkPersonaRequiresImageGeneration(currentAssistant);
const { llmProviders } = useChatContext();
const { setLlmOverride, temperature, updateTemperature } =
const { updateLLMOverride, temperature, updateTemperature } =
llmOverrideManager;
const [isTemperatureExpanded, setIsTemperatureExpanded] = useState(false);
@@ -60,7 +60,7 @@ export const LlmTab = forwardRef<HTMLDivElement, LlmTabProps>(
if (value == null) {
return;
}
setLlmOverride(destructureValue(value));
updateLLMOverride(destructureValue(value));
if (chatSessionId) {
updateModelOverrideForChatSession(chatSessionId, value as string);
}

View File

@@ -11,13 +11,10 @@ import { createFolder } from "../folders/FolderManagement";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import {
AssistantsIconSkeleton,
ClosedBookIcon,
} from "@/components/icons/icons";
import { AssistantsIconSkeleton } from "@/components/icons/icons";
import { PagesTab } from "./PagesTab";
import { pageType } from "./types";
import LogoType from "@/components/header/LogoType";
import LogoWithText from "@/components/header/LogoWithText";
interface HistorySidebarProps {
page: pageType;
@@ -102,16 +99,19 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
flex
flex-col relative
h-screen
pt-2
transition-transform
`}
>
<LogoType
showArrow={true}
toggled={toggled}
page={page}
toggleSidebar={toggleSidebar}
explicitlyUntoggle={explicitlyUntoggle}
/>
<div className="pl-2">
<LogoWithText
showArrow={true}
toggled={toggled}
page={page}
toggleSidebar={toggleSidebar}
explicitlyUntoggle={explicitlyUntoggle}
/>
</div>
{page == "chat" && (
<div className="mx-3 mt-4 gap-y-1 flex-col text-text-history-sidebar-button flex gap-x-1.5 items-center items-center">
<Link

View File

@@ -1,53 +1,83 @@
"use client";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/Logo";
import { Logo } from "@/components/logo/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import Link from "next/link";
import { useContext } from "react";
import { FiSidebar } from "react-icons/fi";
import { LogoType } from "@/components/logo/Logo";
import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
import { useRouter } from "next/navigation";
export function LogoComponent({
enterpriseSettings,
backgroundToggled,
show,
isAdmin,
}: {
enterpriseSettings: EnterpriseSettings | null;
backgroundToggled?: boolean;
show?: boolean;
isAdmin?: boolean;
}) {
const router = useRouter();
return (
<button
onClick={isAdmin ? () => router.push("/chat") : () => {}}
className={`max-w-[200px] ${
!show && "mobile:hidden"
} flex items-center gap-x-1`}
>
{enterpriseSettings && enterpriseSettings.application_name ? (
<>
<div className="flex-none my-auto">
<Logo height={24} width={24} />
</div>
<div className="w-full">
<HeaderTitle backgroundToggled={backgroundToggled}>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-left text-subtle">Powered by Onyx</p>
)}
</div>
</>
) : (
<LogoType />
)}
</button>
);
}
export default function FixedLogo({
// Whether the sidebar is toggled or not
backgroundToggled,
}: {
backgroundToggled?: boolean;
}) {
const combinedSettings = useContext(SettingsContext);
const settings = combinedSettings?.settings;
const enterpriseSettings = combinedSettings?.enterpriseSettings;
return (
<>
<Link
href="/chat"
className="fixed cursor-pointer flex z-40 left-4 top-2 h-8"
className="fixed cursor-pointer flex z-40 left-4 top-3 h-8"
>
<div className="max-w-[200px] mobile:hidden flex items-center gap-x-1 my-auto">
<div className="flex-none my-auto">
<Logo height={24} width={24} />
</div>
<div className="w-full">
{enterpriseSettings && enterpriseSettings.application_name ? (
<div>
<HeaderTitle backgroundToggled={backgroundToggled}>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Onyx</p>
)}
</div>
) : (
<HeaderTitle backgroundToggled={backgroundToggled}>
Onyx
</HeaderTitle>
)}
</div>
</div>
<LogoComponent
enterpriseSettings={enterpriseSettings!}
backgroundToggled={backgroundToggled}
/>
</Link>
<div className="mobile:hidden fixed left-4 bottom-4">
<FiSidebar className="text-text-mobile-sidebar" />
<FiSidebar
className={`${
backgroundToggled
? "text-text-mobile-sidebar-toggled"
: "text-text-mobile-sidebar-untoggled"
}`}
/>
</div>
</>
);

View File

@@ -14,7 +14,6 @@ import { buildClientUrl } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import { EnterpriseSettings, GatingType } from "./admin/settings/interfaces";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/Logo";
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
import { AppProvider } from "@/components/context/AppProvider";
import { PHProvider } from "./providers";
@@ -23,6 +22,7 @@ import CardSection from "@/components/admin/CardSection";
import { Suspense } from "react";
import PostHogPageView from "./PostHogPageView";
import Script from "next/script";
import { LogoType } from "@/components/logo/Logo";
const inter = Inter({
subsets: ["latin"],
@@ -115,8 +115,7 @@ export default async function RootLayout({
return getPageContent(
<div className="flex flex-col items-center justify-center min-h-screen">
<div className="mb-2 flex items-center max-w-[175px]">
<HeaderTitle>Onyx</HeaderTitle>
<Logo height={40} width={40} />
<LogoType />
</div>
<CardSection className="max-w-md">
@@ -124,7 +123,8 @@ export default async function RootLayout({
<p className="text-text-500">
Your Onyx instance was not configured properly and your settings
could not be loaded. This could be due to an admin configuration
issue or an incomplete setup.
issue, an incomplete setup, or backend services that may not be up
and running yet.
</p>
<p className="mt-4">
If you&apos;re an admin, please check{" "}
@@ -144,7 +144,7 @@ export default async function RootLayout({
community on{" "}
<a
className="text-link"
href="https://onyx.app?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ"
target="_blank"
rel="noopener noreferrer"
>
@@ -160,8 +160,7 @@ export default async function RootLayout({
return getPageContent(
<div className="flex flex-col items-center justify-center min-h-screen">
<div className="mb-2 flex items-center max-w-[175px]">
<HeaderTitle>Onyx</HeaderTitle>
<Logo height={40} width={40} />
<LogoType />
</div>
<CardSection className="w-full max-w-md">
<h1 className="text-2xl font-bold mb-4 text-error">

View File

@@ -1,38 +0,0 @@
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { HeaderTitle } from "./header/HeaderTitle";
import LogoType, { Logo } from "./Logo";
import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
export default function LogoTypeContainer({
enterpriseSettings,
}: {
enterpriseSettings: EnterpriseSettings | null;
}) {
const onlyLogo =
!enterpriseSettings ||
!enterpriseSettings.use_custom_logo ||
!enterpriseSettings.application_name;
return (
<div className="flex justify-start items-start w-full gap-x-1 my-auto">
<div className="flex-none w-fit mr-auto my-auto">
{onlyLogo ? <LogoType /> : <Logo height={24} width={24} />}
</div>
{!onlyLogo && (
<div className="w-full">
{enterpriseSettings && enterpriseSettings.application_name ? (
<div>
<HeaderTitle>{enterpriseSettings.application_name}</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Onyx</p>
)}
</div>
) : (
<HeaderTitle>Onyx</HeaderTitle>
)}
</div>
)}
</div>
);
}

View File

@@ -27,7 +27,9 @@ export function MetadataBadge({
size: 12,
className: flexNone ? "flex-none" : "mr-0.5 my-auto",
})}
<div className="my-auto flex">{value}</div>
<p className="max-w-[6rem] text-ellipsis overflow-hidden truncate whitespace-nowrap">
{value}lllaasfasdf
</p>
</div>
);
}

View File

@@ -1,4 +1,4 @@
import { Logo } from "./Logo";
import { Logo } from "./logo/Logo";
import { useContext } from "react";
import { SettingsContext } from "./settings/SettingsProvider";

View File

@@ -136,7 +136,7 @@ export function UserDropdown({
<div
className="
my-auto
bg-background-strong
bg-userdropdown-background
ring-2
ring-transparent
group-hover:ring-background-300/50

View File

@@ -2,7 +2,7 @@
"use client";
import React, { useContext } from "react";
import Link from "next/link";
import { Logo } from "@/components/Logo";
import { Logo } from "@/components/logo/Logo";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { SettingsContext } from "@/components/settings/SettingsProvider";
@@ -14,6 +14,8 @@ import {
TooltipTrigger,
} from "@/components/ui/tooltip";
import { CgArrowsExpandUpLeft } from "react-icons/cg";
import LogoWithText from "@/components/header/LogoWithText";
import { LogoComponent } from "@/app/chat/shared_chat_search/FixedLogo";
interface Item {
name: string | JSX.Element;
@@ -32,36 +34,22 @@ export function AdminSidebar({ collections }: { collections: Collection[] }) {
return null;
}
const settings = combinedSettings.settings;
const enterpriseSettings = combinedSettings.enterpriseSettings;
return (
<div className="text-text-settings-sidebar pl-0">
<nav className="space-y-2">
<div className="w-full ml-4 h-8 justify-start mb-4 flex">
<div className="flex items-center gap-x-1 my-auto">
<div className="flex-none my-auto">
<Logo height={24} width={24} />
</div>
<div className="w-full">
{enterpriseSettings && enterpriseSettings.application_name ? (
<div>
<HeaderTitle>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Onyx</p>
)}
</div>
) : (
<HeaderTitle>Onyx</HeaderTitle>
)}
</div>
</div>
<div className="w-full ml-4 mt-1 h-8 justify-start mb-4 flex">
<LogoComponent
show={true}
enterpriseSettings={enterpriseSettings!}
backgroundToggled={false}
isAdmin={true}
/>
</div>
<div className="flex w-full justify-center">
<Link href="/chat">
<button className="text-sm flex items-center block w-52 py-2.5 flex px-2 text-left hover:bg-opacity-80 cursor-pointer rounded">
<button className="text-sm hover:bg-background-settings-hover flex items-center block w-52 py-2.5 flex px-2 text-left hover:bg-opacity-80 cursor-pointer rounded">
<CgArrowsExpandUpLeft className="my-auto" size={18} />
<p className="ml-1 break-words line-clamp-2 ellipsis leading-none">
Exit Admin

View File

@@ -45,7 +45,7 @@ export default function AssistantBanner({
<div
className={`${
mobile ? "w-full" : "w-36 mx-3"
} flex py-1.5 scale-[1.] rounded-full border border-background-150 justify-center items-center gap-x-2 py-1 px-3 hover:bg-background-125 transition-colors cursor-pointer`}
} flex py-1.5 scale-[1.] rounded-full border border-border-recent-assistants justify-center items-center gap-x-2 py-1 px-3 hover:bg-background-125 transition-colors cursor-pointer`}
onClick={() => onAssistantChange(assistant)}
>
<AssistantIcon
@@ -53,7 +53,7 @@ export default function AssistantBanner({
size="xs"
assistant={assistant}
/>
<span className="font-semibold text-text-800 text-xs truncate max-w-[120px]">
<span className="font-semibold text-text-recent-assistants text-xs truncate max-w-[120px]">
{assistant.name}
</span>
</div>

View File

@@ -1,4 +1,4 @@
import { Logo } from "../Logo";
import { Logo } from "../logo/Logo";
export default function AuthFlowContainer({
children,

View File

@@ -199,7 +199,7 @@ const AssistantSelector = ({
onSelect={(value: string | null) => {
if (value == null) return;
const { modelName, name, provider } = destructureValue(value);
llmOverrideManager.setLlmOverride({
llmOverrideManager.updateLLMOverride({
name,
provider,
modelName,

View File

@@ -1,18 +1,16 @@
"use client";
import { User } from "@/lib/types";
import { UserDropdown } from "../UserDropdown";
import { FiShare2 } from "react-icons/fi";
import { SetStateAction, useContext, useEffect } from "react";
import { NewChatIcon } from "../icons/icons";
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
import { ChatSession } from "@/app/chat/interfaces";
import Link from "next/link";
import { pageType } from "@/app/chat/sessionSidebar/types";
import { useRouter } from "next/navigation";
import { ChatBanner } from "@/app/chat/ChatBanner";
import LogoType from "../header/LogoType";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LlmOverrideManager } from "@/lib/hooks";
import LogoWithText from "../header/LogoWithText";
import { NewChatIcon } from "../icons/icons";
import { SettingsContext } from "../settings/SettingsProvider";
export default function FunctionalHeader({
page,
@@ -21,9 +19,6 @@ export default function FunctionalHeader({
toggleSidebar = () => null,
reset = () => null,
sidebarToggled,
liveAssistant,
onAssistantChange,
llmOverrideManager,
documentSidebarToggled,
toggleUserSettings,
}: {
@@ -34,11 +29,9 @@ export default function FunctionalHeader({
currentChatSession?: ChatSession | null | undefined;
setSharingModalVisible?: (value: SetStateAction<boolean>) => void;
toggleSidebar?: () => void;
liveAssistant?: Persona;
onAssistantChange?: (assistant: Persona) => void;
llmOverrideManager?: LlmOverrideManager;
toggleUserSettings?: () => void;
}) {
const settings = useContext(SettingsContext);
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.metaKey || event.ctrlKey) {
@@ -76,10 +69,11 @@ export default function FunctionalHeader({
return (
<div className="left-0 sticky top-0 z-20 w-full relative flex">
<div className="items-end flex mt-2 cursor-pointer text-text-700 relative flex w-full">
<LogoType
<LogoWithText
assistantId={currentChatSession?.persona_id}
page={page}
toggleSidebar={toggleSidebar}
toggled={sidebarToggled && !settings?.isMobile}
handleNewChat={handleNewChat}
/>
<div className="mt-2 flex w-full h-8">
@@ -103,18 +97,19 @@ export default function FunctionalHeader({
</div>
<div className="invisible">
<LogoType
<LogoWithText
page={page}
toggled={sidebarToggled}
toggleSidebar={toggleSidebar}
handleNewChat={handleNewChat}
/>
</div>
<div className="absolute right-0 top-0 flex gap-x-2">
<div className="absolute right-0 mobile:top-2 desktop:top-0 flex">
{setSharingModalVisible && (
<div
onClick={() => setSharingModalVisible(true)}
className="mobile:hidden my-auto rounded cursor-pointer hover:bg-hover-light"
className="mobile:hidden mr-2 my-auto rounded cursor-pointer hover:bg-hover-light"
>
<FiShare2 size="18" />
</div>
@@ -126,7 +121,7 @@ export default function FunctionalHeader({
/>
</div>
<Link
className="desktop:hidden my-auto"
className="desktop:hidden ml-2 my-auto"
href={
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA &&

View File

@@ -2,7 +2,7 @@ import { useEmbeddingFormContext } from "@/components/context/EmbeddingContext";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { SettingsIcon } from "@/components/icons/icons";
import { Logo } from "@/components/Logo";
import { Logo } from "@/components/logo/Logo";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Link from "next/link";
import { useContext } from "react";

View File

@@ -10,7 +10,8 @@ export function HeaderTitle({
backgroundToggled?: boolean;
}) {
const isString = typeof children === "string";
const textSize = isString && children.length > 10 ? "text-xl" : "text-2xl";
const textSize =
isString && children.length > 10 ? "text-lg mb-[4px] " : "text-2xl";
return (
<h1
@@ -18,7 +19,7 @@ export function HeaderTitle({
backgroundToggled
? "text-text-sidebar-toggled-header"
: "text-text-sidebar-header"
} break-words line-clamp-2 ellipsis text-strong overflow-visible leading-none font-bold`}
} break-words text-left line-clamp-2 ellipsis text-strong overflow-hidden leading-none font-bold`}
>
{children}
</h1>

View File

@@ -2,10 +2,7 @@
import { useContext } from "react";
import { FiSidebar } from "react-icons/fi";
import { SettingsContext } from "../settings/SettingsProvider";
import {
NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED,
NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA,
} from "@/lib/constants";
import { NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA } from "@/lib/constants";
import { LeftToLineIcon, NewChatIcon, RightToLineIcon } from "../icons/icons";
import {
Tooltip,
@@ -14,11 +11,11 @@ import {
TooltipTrigger,
} from "@/components/ui/tooltip";
import { pageType } from "@/app/chat/sessionSidebar/types";
import { Logo } from "../Logo";
import { HeaderTitle } from "./HeaderTitle";
import { Logo } from "../logo/Logo";
import Link from "next/link";
import { LogoComponent } from "@/app/chat/shared_chat_search/FixedLogo";
export default function LogoType({
export default function LogoWithText({
toggleSidebar,
hideOnMobile,
handleNewChat,
@@ -39,50 +36,48 @@ export default function LogoType({
}) {
const combinedSettings = useContext(SettingsContext);
const enterpriseSettings = combinedSettings?.enterpriseSettings;
const useLogoType =
!enterpriseSettings?.use_custom_logo &&
!enterpriseSettings?.application_name;
return (
<div
className={`${
hideOnMobile && "mobile:hidden"
} z-[100] mt-2 h-8 mb-auto shrink-0 flex items-center text-xl`}
} z-[100] ml-2 mt-1 h-8 mb-auto shrink-0 flex gap-x-0 items-center text-xl`}
>
{toggleSidebar && page == "chat" ? (
<button
onClick={() => toggleSidebar()}
className="flex gap-x-2 items-center ml-4 desktop:invisible "
className="flex gap-x-2 items-center ml-0 desktop:hidden "
>
<FiSidebar size={20} className="text-text-mobile-sidebar" />
{!showArrow && (
{!toggled ? (
<Logo className="desktop:hidden -my-2" height={24} width={24} />
) : (
<LogoComponent
show={toggled}
enterpriseSettings={enterpriseSettings!}
backgroundToggled={toggled}
/>
)}
<FiSidebar
size={20}
className={`text-text-mobile-sidebar ${toggled && "mobile:hidden"}`}
/>
</button>
) : (
<div className="mr-1 invisible mb-auto h-6 w-6">
<Logo height={24} width={24} />
lll
</div>
)}
<div
className={`${
showArrow ? "desktop:invisible" : "invisible"
} break-words inline-block w-fit ml-2 text-text-700 text-xl`}
} break-words inline-block w-fit text-text-700 text-xl`}
>
<div className="max-w-[175px]">
{enterpriseSettings && enterpriseSettings.application_name ? (
<div className="w-full">
<HeaderTitle backgroundToggled={toggled}>
{enterpriseSettings.application_name}
</HeaderTitle>
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
<p className="text-xs text-subtle">Powered by Onyx</p>
)}
</div>
) : (
<HeaderTitle backgroundToggled={toggled}>Onyx</HeaderTitle>
)}
</div>
<LogoComponent
enterpriseSettings={enterpriseSettings!}
backgroundToggled={toggled}
/>
</div>
{page == "chat" && !showArrow && (
@@ -90,7 +85,7 @@ export default function LogoType({
<Tooltip>
<TooltipTrigger asChild>
<Link
className="mb-auto mobile:hidden"
className="my-auto mobile:hidden"
href={
`/${page}` +
(NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA && assistantId
@@ -130,10 +125,14 @@ export default function LogoType({
}}
>
{!toggled && !combinedSettings?.isMobile ? (
<RightToLineIcon className="text-sidebar-toggle" />
<RightToLineIcon className="mobile:hidden text-sidebar-toggle" />
) : (
<LeftToLineIcon className="text-sidebar-toggle" />
<LeftToLineIcon className="mobile:hidden text-sidebar-toggle" />
)}
<FiSidebar
size={20}
className="hidden mobile:block text-text-mobile-sidebar"
/>
</button>
</TooltipTrigger>
<TooltipContent>

View File

@@ -39,7 +39,10 @@ import Image, { StaticImageData } from "next/image";
import jiraSVG from "../../../public/Jira.svg";
import confluenceSVG from "../../../public/Confluence.svg";
import openAISVG from "../../../public/Openai.svg";
import amazonSVG from "../../../public/Amazon.svg";
import geminiSVG from "../../../public/Gemini.svg";
import metaSVG from "../../../public/Meta.svg";
import mistralSVG from "../../../public/Mistral.svg";
import openSourceIcon from "../../../public/OpenSource.png";
import litellmIcon from "../../../public/LiteLLM.jpg";
@@ -49,6 +52,7 @@ import asanaIcon from "../../../public/Asana.png";
import anthropicSVG from "../../../public/Anthropic.svg";
import nomicSVG from "../../../public/nomic.svg";
import microsoftIcon from "../../../public/microsoft.png";
import microsoftSVG from "../../../public/Microsoft.svg";
import mixedBreadSVG from "../../../public/Mixedbread.png";
import OCIStorageSVG from "../../../public/OCI.svg";
@@ -1104,6 +1108,26 @@ export const GeminiIcon = ({
className = defaultTailwindCSS,
}: IconProps) => <LogoIcon size={size} className={className} src={geminiSVG} />;
export const AmazonIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => <LogoIcon size={size} className={className} src={amazonSVG} />;
export const MetaIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => <LogoIcon size={size} className={className} src={metaSVG} />;
export const MicrosoftIconSVG = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => <LogoIcon size={size} className={className} src={microsoftSVG} />;
export const MistralIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => <LogoIcon size={size} className={className} src={mistralSVG} />;
export const VoyageIcon = ({
size = 16,
className = defaultTailwindCSS,

View File

@@ -1,7 +1,7 @@
"use client";
import { useContext } from "react";
import { SettingsContext } from "./settings/SettingsProvider";
import { SettingsContext } from "../settings/SettingsProvider";
import Image from "next/image";
export function Logo({
@@ -45,10 +45,10 @@ export function Logo({
);
}
export default function LogoType() {
export function LogoType() {
return (
<Image
className="max-h-8 mr-auto "
className="max-h-8 w-full mr-auto "
src="/logotype.png"
alt="Logo"
width={2640}

View File

@@ -496,7 +496,6 @@ export function HorizontalSourceSelector({
max-w-64
border-border
rounded-lg
bg-background
max-h-96
overflow-y-scroll
overscroll-contain
@@ -508,7 +507,6 @@ export function HorizontalSourceSelector({
w-fit
gap-x-1
hover:bg-hover
bg-hover-light
flex
items-center
bg-background-search-filter
@@ -522,7 +520,7 @@ export function HorizontalSourceSelector({
</div>
</PopoverTrigger>
<PopoverContent
className="bg-background border-border border rounded-md z-[200] p-0"
className="bg-background-search-filter border-border border rounded-md z-[200] p-0"
align="start"
>
<Calendar
@@ -541,7 +539,7 @@ export function HorizontalSourceSelector({
selectValue: timeRange?.selectValue || "",
});
}}
className="rounded-md "
className="rounded-md"
/>
</PopoverContent>
</Popover>

View File

@@ -94,7 +94,22 @@ export function TagFilter({
<div
key={tag.tag_key + tag.tag_value}
onClick={() => onSelectTag(tag)}
className="max-w-full break-all line-clamp-1 text-ellipsis flex text-sm border border-border py-0.5 px-2 rounded cursor-pointer bg-background hover:bg-hover"
className={`
max-w-full
break-all
line-clamp-1
text-ellipsis
flex
text-sm
border
border-border
py-0.5
px-2
rounded
cursor-pointer
bg-background-search-filter
hover:bg-background-search-filter-dropdown
`}
>
{tag.tag_key}
<b>=</b>
@@ -121,7 +136,7 @@ export function TagFilter({
>
<div
ref={popupRef}
className="p-2 border border-border rounded shadow-lg w-72 bg-background"
className="p-2 border border-border rounded shadow-lg w-72 bg-background-search-filter"
>
<div className="flex border-b border-border font-medium pb-1 text-xs mb-2">
<FiTag className="mr-1 my-auto" />
@@ -144,7 +159,11 @@ export function TagFilter({
cursor-pointer
bg-background
hover:bg-hover
${selectedTags.includes(tag) ? "bg-hover" : ""}
${
selectedTags.includes(tag)
? "bg-background-search-filter-dropdown"
: ""
}
`}
>
{tag.tag_key}

View File

@@ -54,7 +54,8 @@ function Calendar({
day_range_middle:
"aria-selected:bg-calendar-range-middle aria-selected:text-calendar-text-in-range dark:aria-selected:bg-calendar-range-middle-dark dark:aria-selected:text-calendar-text-in-range-dark",
day_hidden: "invisible",
day_range_start: "bg-white text-text-900",
day_range_start:
"bg-calendar-background-selected ring-calendar-ring-selected ring text-text-900",
...classNames,
}}
components={{

View File

@@ -11,12 +11,16 @@ import { errorHandlingFetcher } from "./fetcher";
import { useContext, useEffect, useState } from "react";
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
import { SourceMetadata } from "./search/interfaces";
import { destructureValue } from "./llm/utils";
import { destructureValue, structureValue } from "./llm/utils";
import { ChatSession } from "@/app/chat/interfaces";
import { UsersResponse } from "./users/interfaces";
import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { PersonaCategory } from "@/app/admin/assistants/interfaces";
import {
LLMProvider,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
const CREDENTIAL_URL = "/api/manage/admin/credential";
@@ -157,7 +161,7 @@ export interface LlmOverride {
export interface LlmOverrideManager {
llmOverride: LlmOverride;
setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
updateLLMOverride: (newOverride: LlmOverride) => void;
globalDefault: LlmOverride;
setGlobalDefault: React.Dispatch<React.SetStateAction<LlmOverride>>;
temperature: number | null;
@@ -165,52 +169,64 @@ export interface LlmOverrideManager {
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
}
export function useLlmOverride(
llmProviders: LLMProviderDescriptor[],
globalModel?: string | null,
currentChatSession?: ChatSession,
defaultTemperature?: number
): LlmOverrideManager {
const getValidLlmOverride = (
overrideModel: string | null | undefined
): LlmOverride => {
if (overrideModel) {
const model = destructureValue(overrideModel);
const provider = llmProviders.find(
(p) =>
p.model_names.includes(model.modelName) &&
p.provider === model.provider
);
if (provider) {
return { ...model, name: provider.name };
}
}
return { name: "", provider: "", modelName: "" };
};
const [globalDefault, setGlobalDefault] = useState<LlmOverride>(
globalModel != null
? destructureValue(globalModel)
: {
name: "",
provider: "",
modelName: "",
}
getValidLlmOverride(globalModel)
);
const updateLLMOverride = (newOverride: LlmOverride) => {
setLlmOverride(
getValidLlmOverride(
structureValue(
newOverride.name,
newOverride.provider,
newOverride.modelName
)
)
);
};
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
currentChatSession && currentChatSession.current_alternate_model
? destructureValue(currentChatSession.current_alternate_model)
: {
name: "",
provider: "",
modelName: "",
}
? getValidLlmOverride(currentChatSession.current_alternate_model)
: { name: "", provider: "", modelName: "" }
);
const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
setLlmOverride(
chatSession && chatSession.current_alternate_model
? destructureValue(chatSession.current_alternate_model)
? getValidLlmOverride(chatSession.current_alternate_model)
: globalDefault
);
};
const [temperature, setTemperature] = useState<number | null>(
defaultTemperature != undefined ? defaultTemperature : 0
defaultTemperature !== undefined ? defaultTemperature : 0
);
useEffect(() => {
setGlobalDefault(
globalModel != null
? destructureValue(globalModel)
: {
name: "",
provider: "",
modelName: "",
}
);
}, [globalModel]);
setGlobalDefault(getValidLlmOverride(globalModel));
}, [globalModel, llmProviders]);
useEffect(() => {
setTemperature(defaultTemperature !== undefined ? defaultTemperature : 0);
@@ -233,7 +249,7 @@ export function useLlmOverride(
return {
updateModelOverrideForChatSession,
llmOverride,
setLlmOverride,
updateLLMOverride,
globalDefault,
setGlobalDefault,
temperature,
@@ -283,6 +299,7 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
// OpenAI models
"o1-mini": "O1 Mini",
"o1-preview": "O1 Preview",
"o1-2024-12-17": "O1",
"gpt-4": "GPT 4",
"gpt-4o": "GPT 4o",
"gpt-4o-2024-08-06": "GPT 4o (Structured Outputs)",
@@ -302,6 +319,21 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
"gpt-3.5-turbo-16k-0613": "GPT 3.5 Turbo 16k (June 2023)",
"gpt-3.5-turbo-0301": "GPT 3.5 Turbo (March 2023)",
// Amazon models
"amazon.nova-micro@v1": "Amazon Nova Micro",
"amazon.nova-lite@v1": "Amazon Nova Lite",
"amazon.nova-pro@v1": "Amazon Nova Pro",
// Meta models
"llama-3.2-90b-vision-instruct": "Llama 3.2 90B",
"llama-3.2-11b-vision-instruct": "Llama 3.2 11B",
"llama-3.3-70b-instruct": "Llama 3.3 70B",
// Microsoft models
"phi-3.5-mini-instruct": "Phi 3.5 Mini",
"phi-3.5-moe-instruct": "Phi 3.5 MoE",
"phi-3.5-vision-instruct": "Phi 3.5 Vision",
// Anthropic models
"claude-3-opus-20240229": "Claude 3 Opus",
"claude-3-sonnet-20240229": "Claude 3 Sonnet",
@@ -313,6 +345,9 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
"claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet (New)",
"claude-3-5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
"claude-3.5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
"claude-3-5-haiku-20241022": "Claude 3.5 Haiku",
"claude-3-5-haiku@20241022": "Claude 3.5 Haiku",
"claude-3.5-haiku@20241022": "Claude 3.5 Haiku",
// Google Models
"gemini-1.5-pro": "Gemini 1.5 Pro",
@@ -321,6 +356,11 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
"gemini-1.5-flash-001": "Gemini 1.5 Flash",
"gemini-1.5-pro-002": "Gemini 1.5 Pro (v2)",
"gemini-1.5-flash-002": "Gemini 1.5 Flash (v2)",
"gemini-2.0-flash-exp": "Gemini 2.0 Flash (Experimental)",
// Mistral Models
"mistral-large-2411": "Mistral Large 24.11",
"mistral-large@2411": "Mistral Large 24.11",
// Bedrock models
"meta.llama3-1-70b-instruct-v1:0": "Llama 3.1 70B",

View File

@@ -74,6 +74,8 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
// custom claude names
"claude-3.5-sonnet-v2@20241022",
// claude names with AWS Bedrock Suffix
"claude-3-opus-20240229-v1:0",
"claude-3-sonnet-20240229-v1:0",
@@ -93,6 +95,13 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"gemini-1.5-flash-001",
"gemini-1.5-pro-002",
"gemini-1.5-flash-002",
"gemini-2.0-flash-exp",
// amazon models
"amazon.nova-lite@v1",
"amazon.nova-pro@v1",
// meta models
"llama-3.2-90b-vision-instruct",
"llama-3.2-11b-vision-instruct"
];
export function checkLLMSupportsImageInput(model: string) {

View File

@@ -108,6 +108,8 @@ module.exports = {
"background-search-filter": "var(--background-100)",
"background-search-filter-dropdown": "var(--background-100)",
"user-bubble": "var(--user-bubble)",
// colors for sidebar in chat, search, and manage settings
"background-sidebar": "var(--background-100)",
"background-chatbar": "var(--background-100)",
@@ -141,6 +143,14 @@ module.exports = {
// Background for chat messages (user bubbles)
user: "var(--user-bubble)",
"userdropdown-background": "var(--background-150)",
"text-mobile-sidebar-toggled": "var(--text-800)",
"text-mobile-sidebar-untoggled": "var(--text-500)",
"text-editing-message": "var(--text-800)",
"background-sidebar": "var(--background-100)",
"background-search-filter": "var(--background-100)",
"background-search-filter-dropdown": "var(--background-hover)",
"background-toggle": "var(--background-100)",
// Colors for the search toggle buttons
@@ -200,6 +210,8 @@ module.exports = {
"calendar-today-bg-dark": "var(--background-800)",
"calendar-today-text": "var(--text-800)",
"calendar-today-text-dark": "var(--text-200)",
"calendar-background-selected": "var(--white)",
"calendar-ring-selected": "var(--background-900)",
"user-text": "var(--text-800)",
@@ -350,6 +362,39 @@ module.exports = {
fontStyle: {
"token-italic": "italic",
},
calendar: {
// Light mode
"bg-selected": "#4B5563",
"bg-outside-selected": "rgba(75, 85, 99, 0.2)",
"text-muted": "#6B7280",
"text-selected": "#FFFFFF",
"range-start": "#000000",
"range-middle": "#F3F4F6",
"range-end": "#000000",
"text-in-range": "#1F2937",
// Dark mode
"bg-selected-dark": "#6B7280",
"bg-outside-selected-dark": "rgba(107, 114, 128, 0.2)",
"text-muted-dark": "#9CA3AF",
"text-selected-dark": "#F3F4F6",
"range-start-dark": "#374151",
"range-middle-dark": "#4B5563",
"range-end-dark": "#374151",
"text-in-range-dark": "#E5E7EB",
// Hover effects
"hover-bg": "#9CA3AF",
"hover-bg-dark": "#6B7280",
"hover-text": "#374151",
"hover-text-dark": "#E5E7EB",
// Today's date
"today-bg": "#D1D5DB",
"today-bg-dark": "#4B5563",
"today-text": "#374151",
"today-text-dark": "#D1D5DB",
},
},
},
safelist: [