Compare commits

...

12 Commits

Author SHA1 Message Date
Weves
913a6df440 Add error handling 2025-06-10 17:40:57 -07:00
Weves
44aaf9a494 Add more postgres logging 2025-06-10 14:55:04 -07:00
Weves
8646ed6d76 Fix POSTGRES_IDLE_SESSIONS_TIMEOUT 2025-05-23 11:34:10 -07:00
Weves
dac2e95242 Skip temperature for certain models 2025-05-08 10:46:54 -07:00
Weves
d130b7a2e3 Update LLM requirements 2025-05-08 10:34:47 -07:00
Weves
af164bf308 Add o4-mini support 2025-05-06 10:51:58 -07:00
Weves
b72a2c720b Fix llm access 2025-04-29 10:33:34 -07:00
pablonyx
a8f9dad0c6 Quick fix (#4341)
* quick fix

* Revert "quick fix"

This reverts commit f113616276.

* smaller change
2025-04-16 09:08:40 -07:00
Weves
7047e77372 Fix startup w/ seed_db 2025-04-08 11:36:08 -07:00
Weves
0ae1c78503 Add more options to dev compose file 2025-04-07 13:50:37 -07:00
Chris Weaver
725f63713c Adjust pg engine initialization (#4408)
* Adjust pg engine initialization

* Fix mypy

* Rename var

* fix typo

* Fix tests
2025-04-07 13:50:26 -07:00
Weves
d737d437c9 Init engine in slackbot 2025-04-07 13:49:52 -07:00
16 changed files with 163 additions and 65 deletions

View File

@@ -1,3 +1,6 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
@@ -44,6 +47,7 @@ from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.main import lifespan as lifespan_base
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -51,6 +55,20 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Small wrapper around the lifespan of the MIT application.
Basically just calls the base lifespan, and then adds EE-only
steps after."""
async with lifespan_base(app):
# seed the Onyx environment with LLMs, Assistants, etc. based on an optional
# environment variable. Used to automate deployment for multiple environments.
seed_db()
yield
def get_application() -> FastAPI:
# Anything that happens at import time is not guaranteed to be running ee-version
# Anything after the server startup will be running ee version
@@ -58,7 +76,7 @@ def get_application() -> FastAPI:
test_encryption()
application = get_application_base()
application = get_application_base(lifespan_override=lifespan)
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)
@@ -148,10 +166,6 @@ def get_application() -> FastAPI:
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)
# seed the Onyx environment with LLMs, Assistants, etc. based on an optional
# environment variable. Used to automate deployment for multiple environments.
seed_db()
# for debugging discovered routes
# for route in application.router.routes:
# print(f"Path: {route.path}, Methods: {route.methods}")

View File

@@ -1,4 +1,5 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -59,7 +60,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -1,4 +1,5 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -65,7 +66,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -88,7 +88,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -151,16 +151,26 @@ if LOG_POSTGRES_CONN_COUNTS:
global checkout_count
checkout_count += 1
active_connections = connection_proxy._pool.checkedout()
idle_connections = connection_proxy._pool.checkedin()
pool_size = connection_proxy._pool.size()
logger.debug(
"Connection Checkout\n"
f"Active Connections: {active_connections};\n"
f"Idle: {idle_connections};\n"
f"Pool Size: {pool_size};\n"
f"Total connection checkouts: {checkout_count}"
)
try:
active_connections = connection_proxy._pool.checkedout()
idle_connections = connection_proxy._pool.checkedin()
pool_size = connection_proxy._pool.size()
# Get additional pool information
pool_class_name = connection_proxy._pool.__class__.__name__
engine_app_name = SqlEngine.get_app_name() or "unknown"
logger.debug(
"SYNC Engine Connection Checkout\n"
f"Pool Type: {pool_class_name};\n"
f"App Name: {engine_app_name};\n"
f"Active Connections: {active_connections};\n"
f"Idle Connections: {idle_connections};\n"
f"Pool Size: {pool_size};\n"
f"Total Sync Checkouts: {checkout_count}"
)
except Exception as e:
logger.error(f"Error logging checkout: {e}")
@event.listens_for(Engine, "checkin")
def log_checkin(dbapi_connection, connection_record): # type: ignore
@@ -227,17 +237,62 @@ class SqlEngine:
return engine
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
def init_engine(
cls,
pool_size: int,
# is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
max_overflow: int,
**extra_engine_kwargs: Any,
) -> None:
"""NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
important args, and if incorrectly specified, we have run into hitting the pool
limit / using too many connections and overwhelming the database."""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
if cls._engine:
return
connection_string = build_connection_string(
db_api=SYNC_DB_API,
app_name=cls._app_name + "_sync",
use_iam=USE_IAM_AUTH,
)
# Start with base kwargs that are valid for all pool types
final_engine_kwargs: dict[str, Any] = {}
if POSTGRES_USE_NULL_POOL:
# if null pool is specified, then we need to make sure that
# we remove any passed in kwargs related to pool size that would
# cause the initialization to fail
final_engine_kwargs.update(extra_engine_kwargs)
final_engine_kwargs["poolclass"] = pool.NullPool
if "pool_size" in final_engine_kwargs:
del final_engine_kwargs["pool_size"]
if "max_overflow" in final_engine_kwargs:
del final_engine_kwargs["max_overflow"]
else:
final_engine_kwargs["pool_size"] = pool_size
final_engine_kwargs["max_overflow"] = max_overflow
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
# any passed in kwargs override the defaults
final_engine_kwargs.update(extra_engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:
event.listen(engine, "do_connect", provide_iam_token)
cls._engine = engine
@classmethod
def get_engine(cls) -> Engine:
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
raise RuntimeError("Engine not initialized. Must call init_engine first.")
return cls._engine
@classmethod
@@ -435,12 +490,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
# NOTE: don't use `text()` here since we're using the cursor directly
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
finally:
cursor.close()

View File

@@ -99,16 +99,18 @@ def _convert_litellm_message_to_langchain_message(
elif role == "assistant":
return AIMessage(
content=content,
tool_calls=[
{
"name": tool_call.function.name or "",
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
]
if tool_calls
else [],
tool_calls=(
[
{
"name": tool_call.function.name or "",
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
]
if tool_calls
else []
),
)
elif role == "system":
return SystemMessage(content=content)
@@ -409,6 +411,13 @@ class DefaultMultiLLM(LLM):
processed_prompt = _prompt_to_dict(prompt)
self._record_call(processed_prompt)
NO_TEMPERATURE_MODELS = [
"o4-mini",
"o3-mini",
"o3",
"o3-preview",
]
try:
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
@@ -428,9 +437,13 @@ class DefaultMultiLLM(LLM):
# streaming choice
stream=stream,
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
**(
{"temperature": self._temperature}
if self.config.model_name not in NO_TEMPERATURE_MODELS
else {}
),
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -439,6 +452,7 @@ class DefaultMultiLLM(LLM):
if tools
and self.config.model_name
not in [
"o4-mini",
"o3-mini",
"o3-preview",
"o1",

View File

@@ -27,10 +27,13 @@ class WellKnownLLMProviderDescriptor(BaseModel):
OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o4-mini",
"o3-mini",
"o1-mini",
"o3",
"o1",
"gpt-4",
"gpt-4.1",
"gpt-4o",
"gpt-4o-mini",
"o1-preview",

View File

@@ -19,6 +19,7 @@ from httpx_oauth.clients.google import GoogleOAuth2
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
from starlette.types import Lifespan
from onyx import __version__
from onyx.auth.schemas import UserCreate
@@ -264,8 +265,12 @@ def log_http_error(request: Request, exc: Exception) -> JSONResponse:
)
def get_application() -> FastAPI:
application = FastAPI(title="Onyx Backend", version=__version__, lifespan=lifespan)
def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
application = FastAPI(
title="Onyx Backend",
version=__version__,
lifespan=lifespan_override or lifespan,
)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,

View File

@@ -39,6 +39,7 @@ from onyx.context.search.retrieval.search_runner import (
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.models import SlackBot
from onyx.db.search_settings import get_current_search_settings
from onyx.db.slack_bot import fetch_slack_bots
@@ -934,6 +935,9 @@ def _get_socket_client(
if __name__ == "__main__":
# Initialize the SqlEngine
SqlEngine.init_engine(pool_size=20, max_overflow=5)
# Initialize the tenant handler which will manage tenant connections
logger.info("Starting SlackbotHandler")
tenant_handler = SlackbotHandler()

View File

@@ -38,7 +38,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.63.8
litellm==1.66.3
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@@ -47,7 +47,7 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.66.3
openai==1.75.0
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5

View File

@@ -4,6 +4,7 @@ import pytest
from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import SqlEngine
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -48,6 +49,15 @@ instantiate the session directly within the test.
# yield session
@pytest.fixture(scope="session", autouse=True)
def initialize_db() -> None:
# Make sure that the db engine is initialized before any tests are run
SqlEngine.init_engine(
pool_size=10,
max_overflow=5,
)
@pytest.fixture
def vespa_client() -> vespa_fixture:
with get_session_context_manager() as db_session:

View File

@@ -63,6 +63,10 @@ services:
- POSTGRES_HOST=relational_db
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
- POSTGRES_API_SERVER_POOL_SIZE=${POSTGRES_API_SERVER_POOL_SIZE:-}
- POSTGRES_API_SERVER_POOL_OVERFLOW=${POSTGRES_API_SERVER_POOL_OVERFLOW:-}
- POSTGRES_IDLE_SESSIONS_TIMEOUT=${POSTGRES_IDLE_SESSIONS_TIMEOUT:-}
- POSTGRES_POOL_RECYCLE=${POSTGRES_POOL_RECYCLE:-}
- VESPA_HOST=index
- REDIS_HOST=cache
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose

View File

@@ -985,11 +985,6 @@ export function AssistantEditor({
)
: null
}
requiresImageGeneration={
imageGenerationTool
? values.enabled_tools_map[imageGenerationTool.id]
: false
}
onSelect={(selected) => {
if (selected === null) {
setFieldValue("llm_model_version_override", null);

View File

@@ -2,7 +2,6 @@ import { redirect } from "next/navigation";
import { unstable_noStore as noStore } from "next/cache";
import { fetchChatData } from "@/lib/chat/fetchChatData";
import { ChatProvider } from "@/components/context/ChatContext";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
export default async function Layout({
children,
@@ -41,7 +40,6 @@ export default async function Layout({
return (
<>
<InstantSSRAutoRefresh />
<ChatProvider
value={{
proSearchToggled,

View File

@@ -365,7 +365,6 @@ export function UserSettingsModal({
)
: null
}
requiresImageGeneration={false}
onSelect={(selected) => {
if (selected === null) {
handleChangedefaultModel(null);

View File

@@ -22,7 +22,6 @@ interface LLMSelectorProps {
llmProviders: LLMProviderDescriptor[];
currentLlm: string | null;
onSelect: (value: string | null) => void;
requiresImageGeneration?: boolean;
}
export const LLMSelector: React.FC<LLMSelectorProps> = ({
@@ -30,7 +29,6 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
llmProviders,
currentLlm,
onSelect,
requiresImageGeneration,
}) => {
const seenModelNames = new Set();
@@ -90,19 +88,14 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
)}
</SelectItem>
{llmOptions.map((option) => {
if (
!requiresImageGeneration ||
checkLLMSupportsImageInput(option.name)
) {
return (
<SelectItem key={option.value} value={option.value}>
<div className="my-1 flex items-center">
{option.icon && option.icon({ size: 16 })}
<span className="ml-2">{option.name}</span>
</div>
</SelectItem>
);
}
return (
<SelectItem key={option.value} value={option.value}>
<div className="my-1 flex items-center">
{option.icon && option.icon({ size: 16 })}
<span className="ml-2">{option.name}</span>
</div>
</SelectItem>
);
})}
</SelectContent>
</Select>