mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
Compare commits
12 Commits
v2.2.1
...
pool-size-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
913a6df440 | ||
|
|
44aaf9a494 | ||
|
|
8646ed6d76 | ||
|
|
dac2e95242 | ||
|
|
d130b7a2e3 | ||
|
|
af164bf308 | ||
|
|
b72a2c720b | ||
|
|
a8f9dad0c6 | ||
|
|
7047e77372 | ||
|
|
0ae1c78503 | ||
|
|
725f63713c | ||
|
|
d737d437c9 |
@@ -1,3 +1,6 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import BASE_SCOPES
|
||||
@@ -44,6 +47,7 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.main import get_application as get_application_base
|
||||
from onyx.main import include_auth_router_with_prefix
|
||||
from onyx.main import include_router_with_global_prefix_prepended
|
||||
from onyx.main import lifespan as lifespan_base
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -51,6 +55,20 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Small wrapper around the lifespan of the MIT application.
|
||||
Basically just calls the base lifespan, and then adds EE-only
|
||||
steps after."""
|
||||
|
||||
async with lifespan_base(app):
|
||||
# seed the Onyx environment with LLMs, Assistants, etc. based on an optional
|
||||
# environment variable. Used to automate deployment for multiple environments.
|
||||
seed_db()
|
||||
|
||||
yield
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
# Anything that happens at import time is not guaranteed to be running ee-version
|
||||
# Anything after the server startup will be running ee version
|
||||
@@ -58,7 +76,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
test_encryption()
|
||||
|
||||
application = get_application_base()
|
||||
application = get_application_base(lifespan_override=lifespan)
|
||||
|
||||
if MULTI_TENANT:
|
||||
add_tenant_id_middleware(application, logger)
|
||||
@@ -148,10 +166,6 @@ def get_application() -> FastAPI:
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
# seed the Onyx environment with LLMs, Assistants, etc. based on an optional
|
||||
# environment variable. Used to automate deployment for multiple environments.
|
||||
seed_db()
|
||||
|
||||
# for debugging discovered routes
|
||||
# for route in application.router.routes:
|
||||
# print(f"Path: {route.path}, Methods: {route.methods}")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -59,7 +60,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -65,7 +66,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
# "SSL connection has been closed unexpectedly"
|
||||
# actually setting the spawn method in the cloud fixes 95% of these.
|
||||
# setting pre ping might help even more, but not worrying about that yet
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -88,7 +88,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -151,16 +151,26 @@ if LOG_POSTGRES_CONN_COUNTS:
|
||||
global checkout_count
|
||||
checkout_count += 1
|
||||
|
||||
active_connections = connection_proxy._pool.checkedout()
|
||||
idle_connections = connection_proxy._pool.checkedin()
|
||||
pool_size = connection_proxy._pool.size()
|
||||
logger.debug(
|
||||
"Connection Checkout\n"
|
||||
f"Active Connections: {active_connections};\n"
|
||||
f"Idle: {idle_connections};\n"
|
||||
f"Pool Size: {pool_size};\n"
|
||||
f"Total connection checkouts: {checkout_count}"
|
||||
)
|
||||
try:
|
||||
active_connections = connection_proxy._pool.checkedout()
|
||||
idle_connections = connection_proxy._pool.checkedin()
|
||||
pool_size = connection_proxy._pool.size()
|
||||
|
||||
# Get additional pool information
|
||||
pool_class_name = connection_proxy._pool.__class__.__name__
|
||||
engine_app_name = SqlEngine.get_app_name() or "unknown"
|
||||
|
||||
logger.debug(
|
||||
"SYNC Engine Connection Checkout\n"
|
||||
f"Pool Type: {pool_class_name};\n"
|
||||
f"App Name: {engine_app_name};\n"
|
||||
f"Active Connections: {active_connections};\n"
|
||||
f"Idle Connections: {idle_connections};\n"
|
||||
f"Pool Size: {pool_size};\n"
|
||||
f"Total Sync Checkouts: {checkout_count}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging checkout: {e}")
|
||||
|
||||
@event.listens_for(Engine, "checkin")
|
||||
def log_checkin(dbapi_connection, connection_record): # type: ignore
|
||||
@@ -227,17 +237,62 @@ class SqlEngine:
|
||||
return engine
|
||||
|
||||
@classmethod
|
||||
def init_engine(cls, **engine_kwargs: Any) -> None:
|
||||
def init_engine(
|
||||
cls,
|
||||
pool_size: int,
|
||||
# is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
|
||||
max_overflow: int,
|
||||
**extra_engine_kwargs: Any,
|
||||
) -> None:
|
||||
"""NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
|
||||
important args, and if incorrectly specified, we have run into hitting the pool
|
||||
limit / using too many connections and overwhelming the database."""
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine(**engine_kwargs)
|
||||
if cls._engine:
|
||||
return
|
||||
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API,
|
||||
app_name=cls._app_name + "_sync",
|
||||
use_iam=USE_IAM_AUTH,
|
||||
)
|
||||
|
||||
# Start with base kwargs that are valid for all pool types
|
||||
final_engine_kwargs: dict[str, Any] = {}
|
||||
|
||||
if POSTGRES_USE_NULL_POOL:
|
||||
# if null pool is specified, then we need to make sure that
|
||||
# we remove any passed in kwargs related to pool size that would
|
||||
# cause the initialization to fail
|
||||
final_engine_kwargs.update(extra_engine_kwargs)
|
||||
|
||||
final_engine_kwargs["poolclass"] = pool.NullPool
|
||||
if "pool_size" in final_engine_kwargs:
|
||||
del final_engine_kwargs["pool_size"]
|
||||
if "max_overflow" in final_engine_kwargs:
|
||||
del final_engine_kwargs["max_overflow"]
|
||||
else:
|
||||
final_engine_kwargs["pool_size"] = pool_size
|
||||
final_engine_kwargs["max_overflow"] = max_overflow
|
||||
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
|
||||
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
|
||||
|
||||
# any passed in kwargs override the defaults
|
||||
final_engine_kwargs.update(extra_engine_kwargs)
|
||||
|
||||
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
|
||||
# echo=True here for inspecting all emitted db queries
|
||||
engine = create_engine(connection_string, **final_engine_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
event.listen(engine, "do_connect", provide_iam_token)
|
||||
|
||||
cls._engine = engine
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
if not cls._engine:
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine()
|
||||
raise RuntimeError("Engine not initialized. Must call init_engine first.")
|
||||
return cls._engine
|
||||
|
||||
@classmethod
|
||||
@@ -435,12 +490,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
dbapi_connection = connection.connection
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
# NOTE: don't use `text()` here since we're using the cursor directly
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -99,16 +99,18 @@ def _convert_litellm_message_to_langchain_message(
|
||||
elif role == "assistant":
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=[
|
||||
{
|
||||
"name": tool_call.function.name or "",
|
||||
"args": json.loads(tool_call.function.arguments),
|
||||
"id": tool_call.id,
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
tool_calls=(
|
||||
[
|
||||
{
|
||||
"name": tool_call.function.name or "",
|
||||
"args": json.loads(tool_call.function.arguments),
|
||||
"id": tool_call.id,
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else []
|
||||
),
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
@@ -409,6 +411,13 @@ class DefaultMultiLLM(LLM):
|
||||
processed_prompt = _prompt_to_dict(prompt)
|
||||
self._record_call(processed_prompt)
|
||||
|
||||
NO_TEMPERATURE_MODELS = [
|
||||
"o4-mini",
|
||||
"o3-mini",
|
||||
"o3",
|
||||
"o3-preview",
|
||||
]
|
||||
|
||||
try:
|
||||
return litellm.completion(
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
@@ -428,9 +437,13 @@ class DefaultMultiLLM(LLM):
|
||||
# streaming choice
|
||||
stream=stream,
|
||||
# model params
|
||||
temperature=0,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
**(
|
||||
{"temperature": self._temperature}
|
||||
if self.config.model_name not in NO_TEMPERATURE_MODELS
|
||||
else {}
|
||||
),
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
# or else OpenAI throws an error
|
||||
@@ -439,6 +452,7 @@ class DefaultMultiLLM(LLM):
|
||||
if tools
|
||||
and self.config.model_name
|
||||
not in [
|
||||
"o4-mini",
|
||||
"o3-mini",
|
||||
"o3-preview",
|
||||
"o1",
|
||||
|
||||
@@ -27,10 +27,13 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
OPEN_AI_MODEL_NAMES = [
|
||||
"o4-mini",
|
||||
"o3-mini",
|
||||
"o1-mini",
|
||||
"o3",
|
||||
"o1",
|
||||
"gpt-4",
|
||||
"gpt-4.1",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o1-preview",
|
||||
|
||||
@@ -19,6 +19,7 @@ from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.types import Lifespan
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.auth.schemas import UserCreate
|
||||
@@ -264,8 +265,12 @@ def log_http_error(request: Request, exc: Exception) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(title="Onyx Backend", version=__version__, lifespan=lifespan)
|
||||
def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
application = FastAPI(
|
||||
title="Onyx Backend",
|
||||
version=__version__,
|
||||
lifespan=lifespan_override or lifespan,
|
||||
)
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
|
||||
@@ -39,6 +39,7 @@ from onyx.context.search.retrieval.search_runner import (
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.models import SlackBot
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.slack_bot import fetch_slack_bots
|
||||
@@ -934,6 +935,9 @@ def _get_socket_client(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize the SqlEngine
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
|
||||
# Initialize the tenant handler which will manage tenant connections
|
||||
logger.info("Starting SlackbotHandler")
|
||||
tenant_handler = SlackbotHandler()
|
||||
|
||||
@@ -38,7 +38,7 @@ langchainhub==0.1.21
|
||||
langgraph==0.2.72
|
||||
langgraph-checkpoint==2.0.13
|
||||
langgraph-sdk==0.1.44
|
||||
litellm==1.63.8
|
||||
litellm==1.66.3
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
llama-index==0.9.45
|
||||
@@ -47,7 +47,7 @@ msal==1.28.0
|
||||
nltk==3.8.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
openai==1.66.3
|
||||
openai==1.75.0
|
||||
openpyxl==3.1.2
|
||||
playwright==1.41.2
|
||||
psutil==5.9.5
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
@@ -48,6 +49,15 @@ instantiate the session directly within the test.
|
||||
# yield session
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def initialize_db() -> None:
|
||||
# Make sure that the db engine is initialized before any tests are run
|
||||
SqlEngine.init_engine(
|
||||
pool_size=10,
|
||||
max_overflow=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vespa_client() -> vespa_fixture:
|
||||
with get_session_context_manager() as db_session:
|
||||
|
||||
@@ -63,6 +63,10 @@ services:
|
||||
- POSTGRES_HOST=relational_db
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
|
||||
- POSTGRES_API_SERVER_POOL_SIZE=${POSTGRES_API_SERVER_POOL_SIZE:-}
|
||||
- POSTGRES_API_SERVER_POOL_OVERFLOW=${POSTGRES_API_SERVER_POOL_OVERFLOW:-}
|
||||
- POSTGRES_IDLE_SESSIONS_TIMEOUT=${POSTGRES_IDLE_SESSIONS_TIMEOUT:-}
|
||||
- POSTGRES_POOL_RECYCLE=${POSTGRES_POOL_RECYCLE:-}
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
|
||||
|
||||
@@ -985,11 +985,6 @@ export function AssistantEditor({
|
||||
)
|
||||
: null
|
||||
}
|
||||
requiresImageGeneration={
|
||||
imageGenerationTool
|
||||
? values.enabled_tools_map[imageGenerationTool.id]
|
||||
: false
|
||||
}
|
||||
onSelect={(selected) => {
|
||||
if (selected === null) {
|
||||
setFieldValue("llm_model_version_override", null);
|
||||
|
||||
@@ -2,7 +2,6 @@ import { redirect } from "next/navigation";
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { fetchChatData } from "@/lib/chat/fetchChatData";
|
||||
import { ChatProvider } from "@/components/context/ChatContext";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
|
||||
export default async function Layout({
|
||||
children,
|
||||
@@ -41,7 +40,6 @@ export default async function Layout({
|
||||
|
||||
return (
|
||||
<>
|
||||
<InstantSSRAutoRefresh />
|
||||
<ChatProvider
|
||||
value={{
|
||||
proSearchToggled,
|
||||
|
||||
@@ -365,7 +365,6 @@ export function UserSettingsModal({
|
||||
)
|
||||
: null
|
||||
}
|
||||
requiresImageGeneration={false}
|
||||
onSelect={(selected) => {
|
||||
if (selected === null) {
|
||||
handleChangedefaultModel(null);
|
||||
|
||||
@@ -22,7 +22,6 @@ interface LLMSelectorProps {
|
||||
llmProviders: LLMProviderDescriptor[];
|
||||
currentLlm: string | null;
|
||||
onSelect: (value: string | null) => void;
|
||||
requiresImageGeneration?: boolean;
|
||||
}
|
||||
|
||||
export const LLMSelector: React.FC<LLMSelectorProps> = ({
|
||||
@@ -30,7 +29,6 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
|
||||
llmProviders,
|
||||
currentLlm,
|
||||
onSelect,
|
||||
requiresImageGeneration,
|
||||
}) => {
|
||||
const seenModelNames = new Set();
|
||||
|
||||
@@ -90,19 +88,14 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
|
||||
)}
|
||||
</SelectItem>
|
||||
{llmOptions.map((option) => {
|
||||
if (
|
||||
!requiresImageGeneration ||
|
||||
checkLLMSupportsImageInput(option.name)
|
||||
) {
|
||||
return (
|
||||
<SelectItem key={option.value} value={option.value}>
|
||||
<div className="my-1 flex items-center">
|
||||
{option.icon && option.icon({ size: 16 })}
|
||||
<span className="ml-2">{option.name}</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<SelectItem key={option.value} value={option.value}>
|
||||
<div className="my-1 flex items-center">
|
||||
{option.icon && option.icon({ size: 16 })}
|
||||
<span className="ml-2">{option.name}</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
);
|
||||
})}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
Reference in New Issue
Block a user