Compare commits

12 Commits

Author        SHA1         Message                                  Date
Weves         913a6df440   Add error handling                       2025-06-10 17:40:57 -07:00
Weves         44aaf9a494   Add more postgres logging                2025-06-10 14:55:04 -07:00
Weves         8646ed6d76   Fix POSTGRES_IDLE_SESSIONS_TIMEOUT       2025-05-23 11:34:10 -07:00
Weves         dac2e95242   Skip temperature for certain models      2025-05-08 10:46:54 -07:00
Weves         d130b7a2e3   Update LLM requirements                  2025-05-08 10:34:47 -07:00
Weves         af164bf308   Add o4-mini support                      2025-05-06 10:51:58 -07:00
Weves         b72a2c720b   Fix llm access                           2025-04-29 10:33:34 -07:00
pablonyx      a8f9dad0c6   Quick fix (#4341)                        2025-04-16 09:08:40 -07:00
                           * quick fix
                           * Revert "quick fix"
                             This reverts commit f113616276.
                           * smaller chnage
Weves         7047e77372   Fix startup w/ seed_db                   2025-04-08 11:36:08 -07:00
Weves         0ae1c78503   Add more options to dev compose file     2025-04-07 13:50:37 -07:00
Chris Weaver  725f63713c   Adjust pg engine intialization (#4408)   2025-04-07 13:50:26 -07:00
                           * Adjust pg engine intialization
                           * Fix mypy
                           * Rename var
                           * fix typo
                           * Fix tests
Weves         d737d437c9   Init engine in slackbot                  2025-04-07 13:49:52 -07:00
35 changed files with 235 additions and 299 deletions

View File

@@ -1,3 +1,6 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
@@ -44,6 +47,7 @@ from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.main import lifespan as lifespan_base
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -51,6 +55,20 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Small wrapper around the lifespan of the MIT application.
Basically just calls the base lifespan, and then adds EE-only
steps after."""
async with lifespan_base(app):
# seed the Onyx environment with LLMs, Assistants, etc. based on an optional
# environment variable. Used to automate deployment for multiple environments.
seed_db()
yield
def get_application() -> FastAPI:
# Anything that happens at import time is not guaranteed to be running ee-version
# Anything after the server startup will be running ee version
@@ -58,7 +76,7 @@ def get_application() -> FastAPI:
test_encryption()
application = get_application_base()
application = get_application_base(lifespan_override=lifespan)
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)
@@ -148,10 +166,6 @@ def get_application() -> FastAPI:
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)
# seed the Onyx environment with LLMs, Assistants, etc. based on an optional
# environment variable. Used to automate deployment for multiple environments.
seed_db()
# for debugging discovered routes
# for route in application.router.routes:
# print(f"Path: {route.path}, Methods: {route.methods}")

View File

@@ -65,17 +65,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
app.state.gpu_type = gpu_type
try:
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
except Exception as e:
logger.warning(
f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
"This is not a critical error and the model server will continue to run."
)
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")

View File

@@ -1,4 +1,5 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -59,7 +60,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -1,4 +1,5 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -65,7 +66,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -88,7 +88,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import fast_gpu_status_request
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -88,9 +88,7 @@ class Answer:
rerank_settings is not None
and rerank_settings.rerank_provider_type is not None
)
allow_agent_reranking = (
fast_gpu_status_request(indexing=False) or using_cloud_reranking
)
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
# TODO: this is a hack to force the query to be used for the search tool
# this should be removed once we fully unify graph inputs (i.e.

View File

@@ -420,9 +420,6 @@ EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)

View File

@@ -114,7 +114,6 @@ class ConfluenceConnector(
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
self.allow_images = False
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
@@ -159,9 +158,6 @@ class ConfluenceConnector(
"max_backoff_seconds": 60,
}
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@@ -237,9 +233,7 @@ class ConfluenceConnector(
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
page_url = f"{self.wiki_base}{page['_links']['webui']}"
# Get the page content
page_content = extract_text_from_confluence_html(
@@ -270,7 +264,6 @@ class ConfluenceConnector(
self.confluence_client,
attachment,
page_id,
self.allow_images,
)
if result and result.text:
@@ -311,14 +304,13 @@ class ConfluenceConnector(
if "version" in page and "by" in page["version"]:
author = page["version"]["by"]
display_name = author.get("displayName", "Unknown")
email = author.get("email", "unknown@domain.invalid")
primary_owners.append(
BasicExpertInfo(display_name=display_name, email=email)
)
primary_owners.append(BasicExpertInfo(display_name=display_name))
# Create the document
return Document(
id=page_url,
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@@ -381,7 +373,6 @@ class ConfluenceConnector(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
allow_images=self.allow_images,
)
if response is None:
continue

View File

@@ -112,7 +112,6 @@ def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
allow_images: bool,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
@@ -120,7 +119,7 @@ def process_attachment(
"""
try:
# Get the media type from the attachment metadata
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
media_type = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment):
return AttachmentProcessingResult(
@@ -139,14 +138,7 @@ def process_attachment(
attachment_size = attachment["extensions"]["fileSize"]
if media_type.startswith("image/"):
if not allow_images:
return AttachmentProcessingResult(
text=None,
file_name=None,
error="Image downloading is not enabled",
)
else:
if not media_type.startswith("image/"):
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
@@ -302,7 +294,6 @@ def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
allow_images: bool,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
@@ -318,7 +309,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_id, allow_images)
result = process_attachment(confluence_client, attachment, page_id)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@@ -185,8 +184,6 @@ def instantiate_connector(
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
return connector

View File

@@ -86,7 +86,6 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
@@ -102,7 +101,6 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
allow_images=allow_images,
)
@@ -236,10 +234,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._retrieved_ids: set[str] = set()
self.allow_images = False
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@property
def primary_admin_email(self) -> str:
@@ -906,7 +900,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
)
# Fetch files in batches
@@ -1104,9 +1097,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_service.files().list(pageSize=1, fields="files(id)").execute()
if isinstance(self._creds, ServiceAccountCredentials):
# default is ~17mins of retries, don't do that here since this is called from
# the UI
retry_builder(tries=3, delay=0.1)(get_root_folder_id)(drive_service)
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
status_code = e.resp.status if e.resp else None

View File

@@ -79,7 +79,6 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
@@ -88,10 +87,6 @@ def _extract_sections_basic(
link = file.get("webViewLink", "")
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
@@ -212,7 +207,6 @@ def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: Callable[[], GoogleDriveService],
docs_service: Callable[[], GoogleDocsService],
allow_images: bool,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
@@ -242,7 +236,7 @@ def convert_drive_item_to_document(
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _extract_sections_basic(file, drive_service(), allow_images)
sections = _extract_sections_basic(file, drive_service())
# If we still don't have any sections, skip this file
if not sections:

View File

@@ -1,7 +1,6 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from googleapiclient.discovery import Resource # type: ignore
@@ -37,12 +36,12 @@ def _generate_time_range_filter(
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += (
f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'"
)
if end is not None:
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
return time_range_filter

View File

@@ -17,12 +17,9 @@ logger = setup_logger()
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. This is now addressed by checkpointing.
#
# NOTE: We previously tried to combat this here by adding a very
# long retry period (~20 minutes of trying, one request a minute.)
# This is no longer necessary due to checkpointing.
add_retries = retry_builder(tries=5, max_delay=10)
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
PAGE_TOKEN_KEY = "pageToken"
@@ -40,14 +37,14 @@ class GoogleFields(str, Enum):
def _execute_with_retry(request: Any) -> Any:
max_attempts = 6
max_attempts = 10
attempt = 1
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Drive/Gmail API with the same key
# 1. Other things are also requesting from the Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...

View File

@@ -60,10 +60,6 @@ class BaseConnector(abc.ABC, Generic[CT]):
Default is a no-op (always successful).
"""
def set_allow_images(self, value: bool) -> None:
"""Implement if the underlying connector wants to skip/allow image downloading
based on the application level image analysis setting."""
def build_dummy_checkpoint(self) -> CT:
# TODO: find a way to make this work without type: ignore
return ConnectorCheckpoint(has_more=True) # type: ignore

View File

@@ -18,7 +18,6 @@ from typing_extensions import override
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import SLACK_NUM_THREADS
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@@ -487,6 +486,7 @@ def _process_message(
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
MAX_WORKERS = 2
FAST_TIMEOUT = 1
def __init__(
@@ -496,12 +496,10 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
num_threads: int = SLACK_NUM_THREADS,
) -> None:
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.batch_size = batch_size
self.num_threads = num_threads
self.client: WebClient | None = None
self.fast_client: WebClient | None = None
# just used for efficiency
@@ -595,7 +593,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
new_latest = message_batch[-1]["ts"] if message_batch else latest
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
with ThreadPoolExecutor(max_workers=SlackConnector.MAX_WORKERS) as executor:
futures: list[Future[ProcessedSlackMessage]] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
@@ -706,28 +704,25 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
)
# 3) If channels are specified and regex is not enabled, verify each is accessible
# NOTE: removed this for now since it may be too slow for large workspaces which may
# have some automations which create a lot of channels (100k+)
if self.channels and not self.channel_regex_enabled:
accessible_channels = get_channels(
client=self.fast_client,
exclude_archived=True,
get_public=True,
get_private=True,
)
# For quick lookups by name or ID, build a map:
accessible_channel_names = {ch["name"] for ch in accessible_channels}
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
# if self.channels and not self.channel_regex_enabled:
# accessible_channels = get_channels(
# client=self.fast_client,
# exclude_archived=True,
# get_public=True,
# get_private=True,
# )
# # For quick lookups by name or ID, build a map:
# accessible_channel_names = {ch["name"] for ch in accessible_channels}
# accessible_channel_ids = {ch["id"] for ch in accessible_channels}
# for user_channel in self.channels:
# if (
# user_channel not in accessible_channel_names
# and user_channel not in accessible_channel_ids
# ):
# raise ConnectorValidationError(
# f"Channel '{user_channel}' not found or inaccessible in this workspace."
# )
for user_channel in self.channels:
if (
user_channel not in accessible_channel_names
and user_channel not in accessible_channel_ids
):
raise ConnectorValidationError(
f"Channel '{user_channel}' not found or inaccessible in this workspace."
)
except SlackApiError as e:
slack_error = e.response.get("error", "")

View File

@@ -151,16 +151,26 @@ if LOG_POSTGRES_CONN_COUNTS:
global checkout_count
checkout_count += 1
active_connections = connection_proxy._pool.checkedout()
idle_connections = connection_proxy._pool.checkedin()
pool_size = connection_proxy._pool.size()
logger.debug(
"Connection Checkout\n"
f"Active Connections: {active_connections};\n"
f"Idle: {idle_connections};\n"
f"Pool Size: {pool_size};\n"
f"Total connection checkouts: {checkout_count}"
)
try:
active_connections = connection_proxy._pool.checkedout()
idle_connections = connection_proxy._pool.checkedin()
pool_size = connection_proxy._pool.size()
# Get additional pool information
pool_class_name = connection_proxy._pool.__class__.__name__
engine_app_name = SqlEngine.get_app_name() or "unknown"
logger.debug(
"SYNC Engine Connection Checkout\n"
f"Pool Type: {pool_class_name};\n"
f"App Name: {engine_app_name};\n"
f"Active Connections: {active_connections};\n"
f"Idle Connections: {idle_connections};\n"
f"Pool Size: {pool_size};\n"
f"Total Sync Checkouts: {checkout_count}"
)
except Exception as e:
logger.error(f"Error logging checkout: {e}")
@event.listens_for(Engine, "checkin")
def log_checkin(dbapi_connection, connection_record): # type: ignore
@@ -227,17 +237,62 @@ class SqlEngine:
return engine
@classmethod
def init_engine(cls, **engine_kwargs: Any) -> None:
def init_engine(
cls,
pool_size: int,
# is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
max_overflow: int,
**extra_engine_kwargs: Any,
) -> None:
"""NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
important args, and if incorrectly specified, we have run into hitting the pool
limit / using too many connections and overwhelming the database."""
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
if cls._engine:
return
connection_string = build_connection_string(
db_api=SYNC_DB_API,
app_name=cls._app_name + "_sync",
use_iam=USE_IAM_AUTH,
)
# Start with base kwargs that are valid for all pool types
final_engine_kwargs: dict[str, Any] = {}
if POSTGRES_USE_NULL_POOL:
# if null pool is specified, then we need to make sure that
# we remove any passed in kwargs related to pool size that would
# cause the initialization to fail
final_engine_kwargs.update(extra_engine_kwargs)
final_engine_kwargs["poolclass"] = pool.NullPool
if "pool_size" in final_engine_kwargs:
del final_engine_kwargs["pool_size"]
if "max_overflow" in final_engine_kwargs:
del final_engine_kwargs["max_overflow"]
else:
final_engine_kwargs["pool_size"] = pool_size
final_engine_kwargs["max_overflow"] = max_overflow
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
# any passed in kwargs override the defaults
final_engine_kwargs.update(extra_engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:
event.listen(engine, "do_connect", provide_iam_token)
cls._engine = engine
@classmethod
def get_engine(cls) -> Engine:
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
raise RuntimeError("Engine not initialized. Must call init_engine first.")
return cls._engine
@classmethod
@@ -435,12 +490,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
# NOTE: don't use `text()` here since we're using the cursor directly
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
finally:
cursor.close()
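
A minimal usage sketch of the new init_engine contract, mirroring the call sites updated elsewhere in this compare (the celery workers, the Slack bot entrypoint, and the integration-test fixture); the app name below is illustrative:

# Pool sizing must now be passed explicitly before anything asks for the engine.
SqlEngine.set_app_name("my_process_name")            # illustrative app name
SqlEngine.init_engine(pool_size=20, max_overflow=5)   # values taken from the slackbot entrypoint hunk

engine = SqlEngine.get_engine()  # now raises RuntimeError if init_engine was never called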

View File

@@ -99,16 +99,18 @@ def _convert_litellm_message_to_langchain_message(
elif role == "assistant":
return AIMessage(
content=content,
tool_calls=[
{
"name": tool_call.function.name or "",
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
]
if tool_calls
else [],
tool_calls=(
[
{
"name": tool_call.function.name or "",
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
]
if tool_calls
else []
),
)
elif role == "system":
return SystemMessage(content=content)
@@ -409,6 +411,13 @@ class DefaultMultiLLM(LLM):
processed_prompt = _prompt_to_dict(prompt)
self._record_call(processed_prompt)
NO_TEMPERATURE_MODELS = [
"o4-mini",
"o3-mini",
"o3",
"o3-preview",
]
try:
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
@@ -428,9 +437,13 @@ class DefaultMultiLLM(LLM):
# streaming choice
stream=stream,
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
**(
{"temperature": self._temperature}
if self.config.model_name not in NO_TEMPERATURE_MODELS
else {}
),
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
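
A reduced sketch of the temperature handling added in the hunk above; model_name, temperature, and messages stand in for the instance attributes used in the real call, and litellm is the library this module already imports:

NO_TEMPERATURE_MODELS = ["o4-mini", "o3-mini", "o3", "o3-preview"]

# Only pass `temperature` for models that accept it; the o-series reasoning
# models reject a temperature override, so the kwarg is omitted entirely.
temperature_kwargs = (
    {"temperature": temperature}
    if model_name not in NO_TEMPERATURE_MODELS
    else {}
)
response = litellm.completion(
    model=model_name,
    messages=messages,
    **temperature_kwargs,
)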
@@ -439,6 +452,7 @@ class DefaultMultiLLM(LLM):
if tools
and self.config.model_name
not in [
"o4-mini",
"o3-mini",
"o3-preview",
"o1",

View File

@@ -27,10 +27,13 @@ class WellKnownLLMProviderDescriptor(BaseModel):
OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o4-mini",
"o3-mini",
"o1-mini",
"o3",
"o1",
"gpt-4",
"gpt-4.1",
"gpt-4o",
"gpt-4o-mini",
"o1-preview",

View File

@@ -19,6 +19,7 @@ from httpx_oauth.clients.google import GoogleOAuth2
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
from starlette.types import Lifespan
from onyx import __version__
from onyx.auth.schemas import UserCreate
@@ -264,8 +265,12 @@ def log_http_error(request: Request, exc: Exception) -> JSONResponse:
)
def get_application() -> FastAPI:
application = FastAPI(title="Onyx Backend", version=__version__, lifespan=lifespan)
def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
application = FastAPI(
title="Onyx Backend",
version=__version__,
lifespan=lifespan_override or lifespan,
)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
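
For readability, a condensed sketch (assembled only from this hunk plus the ee/onyx/main.py hunk at the top of this compare; not a verbatim excerpt of either file) of how the new lifespan_override parameter is consumed by the EE entrypoint:

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

from onyx.main import get_application as get_application_base
from onyx.main import lifespan as lifespan_base


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    # Run the MIT startup/shutdown steps first, then the EE-only ones.
    async with lifespan_base(app):
        # seed_db is imported from the EE seeding module (import not shown in the hunk)
        seed_db()
        yield


application = get_application_base(lifespan_override=lifespan)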

View File

@@ -39,9 +39,9 @@ from onyx.context.search.retrieval.search_runner import (
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.models import SlackBot
from onyx.db.search_settings import get_current_search_settings
from onyx.db.slack_bot import fetch_slack_bot
from onyx.db.slack_bot import fetch_slack_bots
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -520,25 +520,6 @@ class SlackbotHandler:
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
# skip cases where the bot is disabled in the web UI
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
with get_session_with_current_tenant() as db_session:
slack_bot = fetch_slack_bot(
db_session=db_session, slack_bot_id=client.slack_bot_id
)
if not slack_bot:
logger.error(
f"Slack bot with ID '{client.slack_bot_id}' not found. Skipping request."
)
return False
if not slack_bot.enabled:
logger.info(
f"Slack bot with ID '{client.slack_bot_id}' is disabled. Skipping request."
)
return False
if req.type == "events_api":
# Verify channel is valid
event = cast(dict[str, Any], req.payload.get("event", {}))
@@ -954,6 +935,9 @@ def _get_socket_client(
if __name__ == "__main__":
# Initialize the SqlEngine
SqlEngine.init_engine(pool_size=20, max_overflow=5)
# Initialize the tenant handler which will manage tenant connections
logger.info("Starting SlackbotHandler")
tenant_handler = SlackbotHandler()

View File

@@ -324,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
logger.info(
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
)
gpu_available = gpu_status_request(indexing=True)
gpu_available = gpu_status_request()
logger.info(f"GPU available: {gpu_available}")
current_settings = get_current_search_settings(db_session)

View File

@@ -1,5 +1,3 @@
from functools import lru_cache
import requests
from retry import retry
@@ -12,7 +10,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
def _get_gpu_status_from_model_server(indexing: bool) -> bool:
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool = True) -> bool:
if indexing:
model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
else:
@@ -29,14 +28,3 @@ def _get_gpu_status_from_model_server(indexing: bool) -> bool:
except requests.RequestException as e:
logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
raise # Re-raise exception to trigger a retry
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool) -> bool:
return _get_gpu_status_from_model_server(indexing)
@lru_cache(maxsize=1)
def fast_gpu_status_request(indexing: bool) -> bool:
"""For use in sync flows, where we don't want to retry / we want to cache this."""
return gpu_status_request(indexing=indexing)
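
Caller-side effect of this consolidation, as seen in the other hunks of this compare (onyx/chat/answer.py and the search-settings setup): the cached fast_gpu_status_request wrapper goes away and call sites use the single retried helper, relying on its indexing=True default:

from onyx.utils.gpu_utils import gpu_status_request

# setup path: checks the indexing model server via the indexing=True default
gpu_available = gpu_status_request()

# answer path: was fast_gpu_status_request(indexing=False); using_cloud_reranking
# comes from the Answer hunk earlier in this compare
allow_agent_reranking = gpu_status_request() or using_cloud_reranking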

View File

@@ -38,7 +38,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.63.8
litellm==1.66.3
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@@ -47,7 +47,7 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.66.3
openai==1.75.0
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5

View File

@@ -1,6 +1,5 @@
import os
import time
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -8,16 +7,15 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.utils import AttachmentProcessingResult
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document
@pytest.fixture
def confluence_connector(space: str) -> ConfluenceConnector:
def confluence_connector() -> ConfluenceConnector:
connector = ConfluenceConnector(
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
space=space,
space=os.environ["CONFLUENCE_TEST_SPACE"],
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
)
@@ -34,15 +32,14 @@ def confluence_connector(space: str) -> ConfluenceConnector:
return connector
@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_basic(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
confluence_connector.set_allow_images(False)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
@@ -53,14 +50,15 @@ def test_confluence_connector_basic(
page_within_a_page_doc: Document | None = None
page_doc: Document | None = None
txt_doc: Document | None = None
for doc in doc_batch:
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
page_doc = doc
elif ".txt" in doc.semantic_identifier:
txt_doc = doc
elif doc.semantic_identifier == "Page Within A Page":
page_within_a_page_doc = doc
else:
pass
assert page_within_a_page_doc is not None
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
@@ -81,7 +79,7 @@ def test_confluence_connector_basic(
assert page_doc.metadata["labels"] == ["testlabel"]
assert page_doc.primary_owners
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
assert len(page_doc.sections) == 2 # page text + attachment text
assert len(page_doc.sections) == 1
page_section = page_doc.sections[0]
assert page_section.text == "test123 " + page_within_a_page_text
@@ -90,65 +88,13 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
)
text_attachment_section = page_doc.sections[1]
assert text_attachment_section.text == "small"
assert text_attachment_section.link
assert text_attachment_section.link.endswith("small-file.txt")
@pytest.mark.parametrize("space", ["MI"])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_confluence_connector_skip_images(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
confluence_connector.set_allow_images(False)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 8
assert sum(len(doc.sections) for doc in doc_batch) == 8
def mock_process_image_attachment(
*args: Any, **kwargs: Any
) -> AttachmentProcessingResult:
"""We need this mock to bypass DB access happening in the connector. Which shouldn't
be done as a rule to begin with, but life is not perfect. Fix it later"""
return AttachmentProcessingResult(
text="Hi_text",
file_name="Hi_filename",
error=None,
assert txt_doc is not None
assert txt_doc.semantic_identifier == "small-file.txt"
assert len(txt_doc.sections) == 1
assert txt_doc.sections[0].text == "small"
assert txt_doc.primary_owners
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
assert (
txt_doc.sections[0].link
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
)
@pytest.mark.parametrize("space", ["MI"])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@patch(
"onyx.connectors.confluence.utils._process_image_attachment",
side_effect=mock_process_image_attachment,
)
def test_confluence_connector_allow_images(
mock_get_api_key: MagicMock,
mock_process_image_attachment: MagicMock,
confluence_connector: ConfluenceConnector,
) -> None:
confluence_connector.set_allow_images(True)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 8
assert sum(len(doc.sections) for doc in doc_batch) == 12

View File

@@ -4,6 +4,7 @@ import pytest
from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import SqlEngine
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -48,6 +49,15 @@ instantiate the session directly within the test.
# yield session
@pytest.fixture(scope="session", autouse=True)
def initialize_db() -> None:
# Make sure that the db engine is initialized before any tests are run
SqlEngine.init_engine(
pool_size=10,
max_overflow=5,
)
@pytest.fixture
def vespa_client() -> vespa_fixture:
with get_session_context_manager() as db_session:

View File

@@ -50,7 +50,7 @@ def answer_instance(
mocker: MockerFixture,
) -> Answer:
mocker.patch(
"onyx.chat.answer.fast_gpu_status_request",
"onyx.chat.answer.gpu_status_request",
return_value=True,
)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
@@ -400,7 +400,7 @@ def test_no_slow_reranking(
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.fast_gpu_status_request",
"onyx.chat.answer.gpu_status_request",
return_value=gpu_enabled,
)
rerank_settings = (

View File

@@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.fast_gpu_status_request",
"onyx.chat.answer.gpu_status_request",
return_value=True,
)
question = config["question"]

View File

@@ -63,6 +63,10 @@ services:
- POSTGRES_HOST=relational_db
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
- POSTGRES_API_SERVER_POOL_SIZE=${POSTGRES_API_SERVER_POOL_SIZE:-}
- POSTGRES_API_SERVER_POOL_OVERFLOW=${POSTGRES_API_SERVER_POOL_OVERFLOW:-}
- POSTGRES_IDLE_SESSIONS_TIMEOUT=${POSTGRES_IDLE_SESSIONS_TIMEOUT:-}
- POSTGRES_POOL_RECYCLE=${POSTGRES_POOL_RECYCLE:-}
- VESPA_HOST=index
- REDIS_HOST=cache
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose

View File

@@ -985,11 +985,6 @@ export function AssistantEditor({
)
: null
}
requiresImageGeneration={
imageGenerationTool
? values.enabled_tools_map[imageGenerationTool.id]
: false
}
onSelect={(selected) => {
if (selected === null) {
setFieldValue("llm_model_version_override", null);

View File

@@ -6,9 +6,11 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import { ThreeDotsLoader } from "@/components/Loading";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { usePopup } from "@/components/admin/connectors/Popup";
import Link from "next/link";
import { SlackChannelConfigsTable } from "./SlackChannelConfigsTable";
import { useSlackBot, useSlackChannelConfigsByBot } from "./hooks";
import { ExistingSlackBotForm } from "../SlackBotUpdateForm";
import { FiPlusSquare } from "react-icons/fi";
import { Separator } from "@/components/ui/separator";
function SlackBotEditPage({
@@ -35,11 +37,7 @@ function SlackBotEditPage({
} = useSlackChannelConfigsByBot(Number(unwrappedParams["bot-id"]));
if (isSlackBotLoading || isSlackChannelConfigsLoading) {
return (
<div className="flex justify-center items-center h-screen">
<ThreeDotsLoader />
</div>
);
return <ThreeDotsLoader />;
}
if (slackBotError || !slackBot) {
@@ -69,7 +67,7 @@ function SlackBotEditPage({
}
return (
<>
<div className="container mx-auto">
<InstantSSRAutoRefresh />
<BackButton routerOverride="/admin/bots" />
@@ -88,18 +86,8 @@ function SlackBotEditPage({
setPopup={setPopup}
/>
</div>
</>
);
}
export default function Page({
params,
}: {
params: Promise<{ "bot-id": string }>;
}) {
return (
<div className="container mx-auto">
<SlackBotEditPage params={params} />
</div>
);
}
export default SlackBotEditPage;

View File

@@ -1383,7 +1383,6 @@ export function ChatPage({
if (!packet) {
continue;
}
console.log("Packet:", JSON.stringify(packet));
if (!initialFetchDetails) {
if (!Object.hasOwn(packet, "user_message_id")) {
@@ -1729,7 +1728,6 @@ export function ChatPage({
}
}
} catch (e: any) {
console.log("Error:", e);
const errorMsg = e.message;
upsertToCompleteMessageMap({
messages: [
@@ -1757,13 +1755,11 @@ export function ChatPage({
completeMessageMapOverride: currentMessageMap(completeMessageDetail),
});
}
console.log("Finished streaming");
setAgenticGenerating(false);
resetRegenerationState(currentSessionId());
updateChatState("input");
if (isNewSession) {
console.log("Setting up new session");
if (finalMessage) {
setSelectedMessageForDocDisplay(finalMessage.message_id);
}

View File

@@ -2,7 +2,6 @@ import { redirect } from "next/navigation";
import { unstable_noStore as noStore } from "next/cache";
import { fetchChatData } from "@/lib/chat/fetchChatData";
import { ChatProvider } from "@/components/context/ChatContext";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
export default async function Layout({
children,
@@ -41,7 +40,6 @@ export default async function Layout({
return (
<>
<InstantSSRAutoRefresh />
<ChatProvider
value={{
proSearchToggled,

View File

@@ -365,7 +365,6 @@ export function UserSettingsModal({
)
: null
}
requiresImageGeneration={false}
onSelect={(selected) => {
if (selected === null) {
handleChangedefaultModel(null);

View File

@@ -22,7 +22,6 @@ interface LLMSelectorProps {
llmProviders: LLMProviderDescriptor[];
currentLlm: string | null;
onSelect: (value: string | null) => void;
requiresImageGeneration?: boolean;
}
export const LLMSelector: React.FC<LLMSelectorProps> = ({
@@ -30,7 +29,6 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
llmProviders,
currentLlm,
onSelect,
requiresImageGeneration,
}) => {
const seenModelNames = new Set();
@@ -90,19 +88,14 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
)}
</SelectItem>
{llmOptions.map((option) => {
if (
!requiresImageGeneration ||
checkLLMSupportsImageInput(option.name)
) {
return (
<SelectItem key={option.value} value={option.value}>
<div className="my-1 flex items-center">
{option.icon && option.icon({ size: 16 })}
<span className="ml-2">{option.name}</span>
</div>
</SelectItem>
);
}
return (
<SelectItem key={option.value} value={option.value}>
<div className="my-1 flex items-center">
{option.icon && option.icon({ size: 16 })}
<span className="ml-2">{option.name}</span>
</div>
</SelectItem>
);
})}
</SelectContent>
</Select>