Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 08:15:48 +00:00)

Compare commits: 12 commits, debug_logg ... pool-size-
| Author | SHA1 | Date |
|---|---|---|
|  | 913a6df440 |  |
|  | 44aaf9a494 |  |
|  | 8646ed6d76 |  |
|  | dac2e95242 |  |
|  | d130b7a2e3 |  |
|  | af164bf308 |  |
|  | b72a2c720b |  |
|  | a8f9dad0c6 |  |
|  | 7047e77372 |  |
|  | 0ae1c78503 |  |
|  | 725f63713c |  |
|  | d737d437c9 |  |
@@ -1,3 +1,6 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+
 from fastapi import FastAPI
 from httpx_oauth.clients.google import GoogleOAuth2
 from httpx_oauth.clients.openid import BASE_SCOPES
@@ -44,6 +47,7 @@ from onyx.configs.constants import AuthType
 from onyx.main import get_application as get_application_base
 from onyx.main import include_auth_router_with_prefix
 from onyx.main import include_router_with_global_prefix_prepended
+from onyx.main import lifespan as lifespan_base
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import global_version
 from shared_configs.configs import MULTI_TENANT
@@ -51,6 +55,20 @@ from shared_configs.configs import MULTI_TENANT
 logger = setup_logger()
 
 
+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+    """Small wrapper around the lifespan of the MIT application.
+
+    Basically just calls the base lifespan, and then adds EE-only
+    steps after."""
+
+    async with lifespan_base(app):
+        # seed the Onyx environment with LLMs, Assistants, etc. based on an optional
+        # environment variable. Used to automate deployment for multiple environments.
+        seed_db()
+
+        yield
+
+
 def get_application() -> FastAPI:
     # Anything that happens at import time is not guaranteed to be running ee-version
     # Anything after the server startup will be running ee version
@@ -58,7 +76,7 @@ def get_application() -> FastAPI:
 
     test_encryption()
 
-    application = get_application_base()
+    application = get_application_base(lifespan_override=lifespan)
 
     if MULTI_TENANT:
         add_tenant_id_middleware(application, logger)
@@ -148,10 +166,6 @@ def get_application() -> FastAPI:
     # Ensure all routes have auth enabled or are explicitly marked as public
     check_ee_router_auth(application)
 
-    # seed the Onyx environment with LLMs, Assistants, etc. based on an optional
-    # environment variable. Used to automate deployment for multiple environments.
-    seed_db()
-
    # for debugging discovered routes
    # for route in application.router.routes:
    #     print(f"Path: {route.path}, Methods: {route.methods}")
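The EE hunks above wrap the MIT application's lifespan rather than duplicating it: an `@asynccontextmanager` lifespan can enter the base lifespan and layer extra startup work on top before yielding to the server. A minimal, self-contained sketch of that pattern (names below are illustrative, not from the repo):

```python
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def base_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    print("base startup")  # e.g. connect to the DB, warm caches
    yield
    print("base shutdown")


@asynccontextmanager
async def wrapped_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    # entering the base lifespan runs its startup code first
    async with base_lifespan(app):
        print("extra startup step")  # e.g. the seed_db() call in the hunk above
        yield
    # base shutdown runs after the inner block exits


app = FastAPI(lifespan=wrapped_lifespan)
```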
@@ -65,17 +65,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
 
     app.state.gpu_type = gpu_type
 
-    try:
-        if TEMP_HF_CACHE_PATH.is_dir():
-            logger.notice("Moving contents of temp_huggingface to huggingface cache.")
-            _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
-            shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
-            logger.notice("Moved contents of temp_huggingface to huggingface cache.")
-    except Exception as e:
-        logger.warning(
-            f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
-            "This is not a critical error and the model server will continue to run."
-        )
+    if TEMP_HF_CACHE_PATH.is_dir():
+        logger.notice("Moving contents of temp_huggingface to huggingface cache.")
+        _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
+        shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
+        logger.notice("Moved contents of temp_huggingface to huggingface cache.")
 
     torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
     logger.notice(f"Torch Threads: {torch.get_num_threads()}")
@@ -1,4 +1,5 @@
 from typing import Any
+from typing import cast
 
 from celery import Celery
 from celery import signals
@@ -59,7 +60,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
+    pool_size = cast(int, sender.concurrency)  # type: ignore
+    SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
@@ -1,4 +1,5 @@
 from typing import Any
+from typing import cast
 
 from celery import Celery
 from celery import signals
@@ -65,7 +66,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     # "SSL connection has been closed unexpectedly"
     # actually setting the spawn method in the cloud fixes 95% of these.
     # setting pre ping might help even more, but not worrying about that yet
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
+    pool_size = cast(int, sender.concurrency)  # type: ignore
+    SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
@@ -88,7 +88,8 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     EXTRA_CONCURRENCY = 4  # small extra fudge factor for connection limits
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY)  # type: ignore
+    pool_size = cast(int, sender.concurrency)  # type: ignore
+    SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
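All three Celery worker hunks make the same mypy-driven change: `sender.concurrency` is untyped from Celery's perspective, so the old code put a blanket `# type: ignore` on the whole `init_engine` call. Narrowing the value once with `typing.cast` confines the ignore to a single assignment and lets the call itself be type-checked. The pattern in isolation, with a stub engine initializer:

```python
from typing import Any, cast


def init_engine(pool_size: int, max_overflow: int) -> None:
    """Stand-in for SqlEngine.init_engine; the real code builds a SQLAlchemy engine."""
    print(f"pool_size={pool_size}, max_overflow={max_overflow}")


class FakeWorker:
    concurrency: Any = 4  # Celery exposes this without a useful static type


def on_worker_init(sender: FakeWorker) -> None:
    # narrow once, so only this line needs an ignore in the real codebase
    pool_size = cast(int, sender.concurrency)
    init_engine(pool_size=pool_size, max_overflow=8)  # fully type-checked call


on_worker_init(FakeWorker())
```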
@@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
 from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
-from onyx.utils.gpu_utils import fast_gpu_status_request
+from onyx.utils.gpu_utils import gpu_status_request
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -88,9 +88,7 @@ class Answer:
             rerank_settings is not None
             and rerank_settings.rerank_provider_type is not None
         )
-        allow_agent_reranking = (
-            fast_gpu_status_request(indexing=False) or using_cloud_reranking
-        )
+        allow_agent_reranking = gpu_status_request() or using_cloud_reranking
 
         # TODO: this is a hack to force the query to be used for the search tool
         # this should be removed once we fully unify graph inputs (i.e.
@@ -420,9 +420,6 @@ EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
 LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
 LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
 
-# Slack specific configs
-SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
-
 DASK_JOB_CLIENT_ENABLED = (
     os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
 )
@@ -114,7 +114,6 @@ class ConfluenceConnector(
         self.timezone_offset = timezone_offset
         self._confluence_client: OnyxConfluence | None = None
         self._fetched_titles: set[str] = set()
-        self.allow_images = False
 
         # Remove trailing slash from wiki_base if present
         self.wiki_base = wiki_base.rstrip("/")
@@ -159,9 +158,6 @@ class ConfluenceConnector(
             "max_backoff_seconds": 60,
         }
 
-    def set_allow_images(self, value: bool) -> None:
-        self.allow_images = value
-
     @property
     def confluence_client(self) -> OnyxConfluence:
         if self._confluence_client is None:
@@ -237,9 +233,7 @@ class ConfluenceConnector(
         # Extract basic page information
         page_id = page["id"]
         page_title = page["title"]
-        page_url = build_confluence_document_id(
-            self.wiki_base, page["_links"]["webui"], self.is_cloud
-        )
+        page_url = f"{self.wiki_base}{page['_links']['webui']}"
 
         # Get the page content
         page_content = extract_text_from_confluence_html(
@@ -270,7 +264,6 @@ class ConfluenceConnector(
                 self.confluence_client,
                 attachment,
                 page_id,
-                self.allow_images,
             )
 
             if result and result.text:
@@ -311,14 +304,13 @@ class ConfluenceConnector(
         if "version" in page and "by" in page["version"]:
             author = page["version"]["by"]
             display_name = author.get("displayName", "Unknown")
-            email = author.get("email", "unknown@domain.invalid")
-            primary_owners.append(
-                BasicExpertInfo(display_name=display_name, email=email)
-            )
+            primary_owners.append(BasicExpertInfo(display_name=display_name))
 
         # Create the document
         return Document(
-            id=page_url,
+            id=build_confluence_document_id(
+                self.wiki_base, page["_links"]["webui"], self.is_cloud
+            ),
             sections=sections,
             source=DocumentSource.CONFLUENCE,
             semantic_identifier=page_title,
@@ -381,7 +373,6 @@ class ConfluenceConnector(
                 confluence_client=self.confluence_client,
                 attachment=attachment,
                 page_id=page["id"],
-                allow_images=self.allow_images,
             )
             if response is None:
                 continue
@@ -112,7 +112,6 @@ def process_attachment(
     confluence_client: "OnyxConfluence",
     attachment: dict[str, Any],
     parent_content_id: str | None,
-    allow_images: bool,
 ) -> AttachmentProcessingResult:
     """
     Processes a Confluence attachment. If it's a document, extracts text,
@@ -120,7 +119,7 @@ def process_attachment(
     """
     try:
         # Get the media type from the attachment metadata
-        media_type: str = attachment.get("metadata", {}).get("mediaType", "")
+        media_type = attachment.get("metadata", {}).get("mediaType", "")
         # Validate the attachment type
         if not validate_attachment_filetype(attachment):
             return AttachmentProcessingResult(
@@ -139,14 +138,7 @@ def process_attachment(
 
         attachment_size = attachment["extensions"]["fileSize"]
 
-        if media_type.startswith("image/"):
-            if not allow_images:
-                return AttachmentProcessingResult(
-                    text=None,
-                    file_name=None,
-                    error="Image downloading is not enabled",
-                )
-        else:
+        if not media_type.startswith("image/"):
             if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
                 logger.warning(
                     f"Skipping {attachment_link} due to size. "
@@ -302,7 +294,6 @@ def convert_attachment_to_content(
     confluence_client: "OnyxConfluence",
     attachment: dict[str, Any],
     page_id: str,
-    allow_images: bool,
 ) -> tuple[str | None, str | None] | None:
     """
     Facade function which:
@@ -318,7 +309,7 @@ def convert_attachment_to_content(
         )
         return None
 
-    result = process_attachment(confluence_client, attachment, page_id, allow_images)
+    result = process_attachment(confluence_client, attachment, page_id)
     if result.error is not None:
         logger.warning(
             f"Attachment {attachment['title']} encountered error: {result.error}"
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
 
 from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
 from onyx.configs.constants import DocumentSource
-from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
 from onyx.connectors.airtable.airtable_connector import AirtableConnector
 from onyx.connectors.asana.connector import AsanaConnector
 from onyx.connectors.axero.connector import AxeroConnector
@@ -185,8 +184,6 @@ def instantiate_connector(
     if new_credentials is not None:
         backend_update_credential_json(credential, new_credentials, db_session)
 
-    connector.set_allow_images(get_image_extraction_and_analysis_enabled())
-
     return connector
@@ -86,7 +86,6 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
 def _convert_single_file(
     creds: Any,
     primary_admin_email: str,
-    allow_images: bool,
     file: dict[str, Any],
 ) -> Document | ConnectorFailure | None:
     user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
@@ -102,7 +101,6 @@
         file=file,
         drive_service=user_drive_service,
         docs_service=docs_service,
-        allow_images=allow_images,
     )
 
 
@@ -236,10 +234,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
         self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
 
         self._retrieved_ids: set[str] = set()
-        self.allow_images = False
-
-    def set_allow_images(self, value: bool) -> None:
-        self.allow_images = value
 
     @property
     def primary_admin_email(self) -> str:
@@ -906,7 +900,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             _convert_single_file,
             self.creds,
             self.primary_admin_email,
-            self.allow_images,
         )
 
         # Fetch files in batches
@@ -1104,9 +1097,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                 drive_service.files().list(pageSize=1, fields="files(id)").execute()
 
             if isinstance(self._creds, ServiceAccountCredentials):
-                # default is ~17mins of retries, don't do that here since this is called from
-                # the UI
-                retry_builder(tries=3, delay=0.1)(get_root_folder_id)(drive_service)
+                retry_builder()(get_root_folder_id)(drive_service)
 
         except HttpError as e:
             status_code = e.resp.status if e.resp else None
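`retry_builder` is an in-repo helper whose definition isn't shown in this compare; from its call sites (`retry_builder(tries=3, delay=0.1)(get_root_folder_id)(drive_service)` above and the module-level `retry_builder(tries=50, max_delay=30)` below) it appears to be a factory returning a retrying decorator. A generic sketch of such a factory, written under that assumption with made-up defaults:

```python
import time
from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")


def retry_builder(
    tries: int = 5, delay: float = 0.5, max_delay: float = 10.0
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Return a decorator that retries with capped exponential backoff (illustrative only)."""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        def wrapped(*args: Any, **kwargs: Any) -> T:
            attempt_delay = delay
            for attempt in range(1, tries + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    if attempt == tries:
                        raise  # out of attempts, propagate the last error
                    time.sleep(attempt_delay)
                    attempt_delay = min(attempt_delay * 2, max_delay)
            raise AssertionError("unreachable")

        return wrapped

    return decorator
```

Read this way, the one-liner in the hunk builds the decorator, wraps `get_root_folder_id`, and invokes the wrapped function immediately with `drive_service`.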
@@ -79,7 +79,6 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
 def _extract_sections_basic(
     file: dict[str, str],
     service: GoogleDriveService,
-    allow_images: bool,
 ) -> list[TextSection | ImageSection]:
     """Extract text and images from a Google Drive file."""
     file_id = file["id"]
@@ -88,10 +87,6 @@
     link = file.get("webViewLink", "")
 
     try:
-        # skip images if not explicitly enabled
-        if not allow_images and is_gdrive_image_mime_type(mime_type):
-            return []
-
         # For Google Docs, Sheets, and Slides, export as plain text
         if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
             export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
@@ -212,7 +207,6 @@ def convert_drive_item_to_document(
     file: GoogleDriveFileType,
     drive_service: Callable[[], GoogleDriveService],
     docs_service: Callable[[], GoogleDocsService],
-    allow_images: bool,
 ) -> Document | ConnectorFailure | None:
     """
     Main entry point for converting a Google Drive file => Document object.
@@ -242,7 +236,7 @@ def convert_drive_item_to_document(
 
     # If we don't have sections yet, use the basic extraction method
     if not sections:
-        sections = _extract_sections_basic(file, drive_service(), allow_images)
+        sections = _extract_sections_basic(file, drive_service())
 
     # If we still don't have any sections, skip this file
     if not sections:
@@ -1,7 +1,6 @@
 from collections.abc import Callable
 from collections.abc import Iterator
 from datetime import datetime
-from datetime import timezone
 
 from googleapiclient.discovery import Resource  # type: ignore
 
@@ -37,12 +36,12 @@ def _generate_time_range_filter(
 ) -> str:
     time_range_filter = ""
     if start is not None:
-        time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
+        time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
         time_range_filter += (
             f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'"
         )
     if end is not None:
-        time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
+        time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
         time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
     return time_range_filter
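Both sides of the `_generate_time_range_filter` change produce an RFC 3339 UTC timestamp for the Drive query; they differ only in how the UTC offset is spelled (`+00:00` from a timezone-aware datetime vs. a literal `"Z"` suffix on a naive one). The equivalence is easy to check — though note `datetime.utcfromtimestamp` is deprecated as of Python 3.12:

```python
from datetime import datetime, timezone

ts = 1700000000.0

aware = datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
naive_z = datetime.utcfromtimestamp(ts).isoformat() + "Z"

print(aware)    # 2023-11-14T22:13:20+00:00
print(naive_z)  # 2023-11-14T22:13:20Z  (same instant, different suffix)
```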
@@ -17,12 +17,9 @@ logger = setup_logger()
 
 
 # Google Drive APIs are quite flakey and may 500 for an
-# extended period of time. This is now addressed by checkpointing.
-#
-# NOTE: We previously tried to combat this here by adding a very
-# long retry period (~20 minutes of trying, one request a minute.)
-# This is no longer necessary due to checkpointing.
-add_retries = retry_builder(tries=5, max_delay=10)
+# extended period of time. Trying to combat here by adding a very
+# long retry period (~20 minutes of trying every minute)
+add_retries = retry_builder(tries=50, max_delay=30)
 
 NEXT_PAGE_TOKEN_KEY = "nextPageToken"
 PAGE_TOKEN_KEY = "pageToken"
@@ -40,14 +37,14 @@ class GoogleFields(str, Enum):
 
 
 def _execute_with_retry(request: Any) -> Any:
-    max_attempts = 6
+    max_attempts = 10
     attempt = 1
 
     while attempt < max_attempts:
         # Note for reasons unknown, the Google API will sometimes return a 429
         # and even after waiting the retry period, it will return another 429.
         # It could be due to a few possibilities:
-        # 1. Other things are also requesting from the Drive/Gmail API with the same key
+        # 1. Other things are also requesting from the Gmail API with the same key
         # 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
         # 3. The retry-after has a maximum and we've already hit the limit for the day
         # or it's something else...
@@ -60,10 +60,6 @@ class BaseConnector(abc.ABC, Generic[CT]):
         Default is a no-op (always successful).
         """
 
-    def set_allow_images(self, value: bool) -> None:
-        """Implement if the underlying connector wants to skip/allow image downloading
-        based on the application level image analysis setting."""
-
     def build_dummy_checkpoint(self) -> CT:
         # TODO: find a way to make this work without type: ignore
         return ConnectorCheckpoint(has_more=True)  # type: ignore
@@ -18,7 +18,6 @@ from typing_extensions import override
 
 from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
 from onyx.configs.app_configs import INDEX_BATCH_SIZE
-from onyx.configs.app_configs import SLACK_NUM_THREADS
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.exceptions import ConnectorValidationError
 from onyx.connectors.exceptions import CredentialExpiredError
@@ -487,6 +486,7 @@ def _process_message(
 
 
 class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
+    MAX_WORKERS = 2
     FAST_TIMEOUT = 1
 
     def __init__(
@@ -496,12 +496,10 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
         # regexes, and will only index channels that fully match the regexes
         channel_regex_enabled: bool = False,
         batch_size: int = INDEX_BATCH_SIZE,
-        num_threads: int = SLACK_NUM_THREADS,
     ) -> None:
         self.channels = channels
         self.channel_regex_enabled = channel_regex_enabled
         self.batch_size = batch_size
-        self.num_threads = num_threads
         self.client: WebClient | None = None
         self.fast_client: WebClient | None = None
         # just used for efficiency
@@ -595,7 +593,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
         new_latest = message_batch[-1]["ts"] if message_batch else latest
 
         # Process messages in parallel using ThreadPoolExecutor
-        with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
+        with ThreadPoolExecutor(max_workers=SlackConnector.MAX_WORKERS) as executor:
             futures: list[Future[ProcessedSlackMessage]] = []
             for message in message_batch:
                 # Capture the current context so that the thread gets the current tenant ID
@@ -706,28 +704,25 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
             )
 
             # 3) If channels are specified and regex is not enabled, verify each is accessible
-            if self.channels and not self.channel_regex_enabled:
-                accessible_channels = get_channels(
-                    client=self.fast_client,
-                    exclude_archived=True,
-                    get_public=True,
-                    get_private=True,
-                )
-                # For quick lookups by name or ID, build a map:
-                accessible_channel_names = {ch["name"] for ch in accessible_channels}
-                accessible_channel_ids = {ch["id"] for ch in accessible_channels}
-
-                for user_channel in self.channels:
-                    if (
-                        user_channel not in accessible_channel_names
-                        and user_channel not in accessible_channel_ids
-                    ):
-                        raise ConnectorValidationError(
-                            f"Channel '{user_channel}' not found or inaccessible in this workspace."
-                        )
+            # NOTE: removed this for now since it may be too slow for large workspaces which may
+            # have some automations which create a lot of channels (100k+)
+            # if self.channels and not self.channel_regex_enabled:
+            #     accessible_channels = get_channels(
+            #         client=self.fast_client,
+            #         exclude_archived=True,
+            #         get_public=True,
+            #         get_private=True,
+            #     )
+            #     # For quick lookups by name or ID, build a map:
+            #     accessible_channel_names = {ch["name"] for ch in accessible_channels}
+            #     accessible_channel_ids = {ch["id"] for ch in accessible_channels}
+
+            #     for user_channel in self.channels:
+            #         if (
+            #             user_channel not in accessible_channel_names
+            #             and user_channel not in accessible_channel_ids
+            #         ):
+            #             raise ConnectorValidationError(
+            #                 f"Channel '{user_channel}' not found or inaccessible in this workspace."
+            #             )
 
         except SlackApiError as e:
             slack_error = e.response.get("error", "")
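Dropping the configurable `num_threads` in favor of a class-level `MAX_WORKERS = 2` hard-caps the per-batch fan-out regardless of connector configuration. Stripped of the Slack specifics, the shape of that pattern is:

```python
from concurrent.futures import Future, ThreadPoolExecutor


class BatchProcessor:
    MAX_WORKERS = 2  # hard cap instead of a per-instance setting

    def process(self, items: list[int]) -> list[int]:
        # at most MAX_WORKERS items are processed concurrently
        with ThreadPoolExecutor(max_workers=BatchProcessor.MAX_WORKERS) as executor:
            futures: list[Future[int]] = [
                executor.submit(lambda x: x * x, item) for item in items
            ]
            return [f.result() for f in futures]


print(BatchProcessor().process([1, 2, 3]))  # [1, 4, 9]
```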
@@ -151,16 +151,26 @@ if LOG_POSTGRES_CONN_COUNTS:
         global checkout_count
         checkout_count += 1
 
-        active_connections = connection_proxy._pool.checkedout()
-        idle_connections = connection_proxy._pool.checkedin()
-        pool_size = connection_proxy._pool.size()
-        logger.debug(
-            "Connection Checkout\n"
-            f"Active Connections: {active_connections};\n"
-            f"Idle: {idle_connections};\n"
-            f"Pool Size: {pool_size};\n"
-            f"Total connection checkouts: {checkout_count}"
-        )
+        try:
+            active_connections = connection_proxy._pool.checkedout()
+            idle_connections = connection_proxy._pool.checkedin()
+            pool_size = connection_proxy._pool.size()
+
+            # Get additional pool information
+            pool_class_name = connection_proxy._pool.__class__.__name__
+            engine_app_name = SqlEngine.get_app_name() or "unknown"
+
+            logger.debug(
+                "SYNC Engine Connection Checkout\n"
+                f"Pool Type: {pool_class_name};\n"
+                f"App Name: {engine_app_name};\n"
+                f"Active Connections: {active_connections};\n"
+                f"Idle Connections: {idle_connections};\n"
+                f"Pool Size: {pool_size};\n"
+                f"Total Sync Checkouts: {checkout_count}"
+            )
+        except Exception as e:
+            logger.error(f"Error logging checkout: {e}")
 
     @event.listens_for(Engine, "checkin")
     def log_checkin(dbapi_connection, connection_record):  # type: ignore
@@ -227,17 +237,62 @@ class SqlEngine:
         return engine
 
     @classmethod
-    def init_engine(cls, **engine_kwargs: Any) -> None:
+    def init_engine(
+        cls,
+        pool_size: int,
+        # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
+        max_overflow: int,
+        **extra_engine_kwargs: Any,
+    ) -> None:
+        """NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
+        important args, and if incorrectly specified, we have run into hitting the pool
+        limit / using too many connections and overwhelming the database."""
         with cls._lock:
-            if not cls._engine:
-                cls._engine = cls._init_engine(**engine_kwargs)
+            if cls._engine:
+                return
+
+            connection_string = build_connection_string(
+                db_api=SYNC_DB_API,
+                app_name=cls._app_name + "_sync",
+                use_iam=USE_IAM_AUTH,
+            )
+
+            # Start with base kwargs that are valid for all pool types
+            final_engine_kwargs: dict[str, Any] = {}
+
+            if POSTGRES_USE_NULL_POOL:
+                # if null pool is specified, then we need to make sure that
+                # we remove any passed in kwargs related to pool size that would
+                # cause the initialization to fail
+                final_engine_kwargs.update(extra_engine_kwargs)
+
+                final_engine_kwargs["poolclass"] = pool.NullPool
+                if "pool_size" in final_engine_kwargs:
+                    del final_engine_kwargs["pool_size"]
+                if "max_overflow" in final_engine_kwargs:
+                    del final_engine_kwargs["max_overflow"]
+            else:
+                final_engine_kwargs["pool_size"] = pool_size
+                final_engine_kwargs["max_overflow"] = max_overflow
+                final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
+                final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
+
+                # any passed in kwargs override the defaults
+                final_engine_kwargs.update(extra_engine_kwargs)
+
+            logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
+            # echo=True here for inspecting all emitted db queries
+            engine = create_engine(connection_string, **final_engine_kwargs)
+
+            if USE_IAM_AUTH:
+                event.listen(engine, "do_connect", provide_iam_token)
+
+            cls._engine = engine
 
     @classmethod
     def get_engine(cls) -> Engine:
         if not cls._engine:
-            with cls._lock:
-                if not cls._engine:
-                    cls._engine = cls._init_engine()
+            raise RuntimeError("Engine not initialized. Must call init_engine first.")
         return cls._engine
 
     @classmethod
@@ -435,12 +490,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
         dbapi_connection = connection.connection
         cursor = dbapi_connection.cursor()
         try:
+            # NOTE: don't use `text()` here since we're using the cursor directly
            cursor.execute(f'SET search_path = "{tenant_id}"')
            if POSTGRES_IDLE_SESSIONS_TIMEOUT:
                cursor.execute(
-                    text(
-                        f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
-                    )
+                    f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
         finally:
             cursor.close()
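The `init_engine` rewrite changes the class's contract: pool sizing is now mandatory at initialization, and `get_engine` fails fast instead of lazily building an engine with implicit defaults. The callers elsewhere in this compare all follow the same two-step protocol; a condensed sketch of that contract, minus the Postgres specifics:

```python
import threading
from typing import Any


class EngineHolder:
    _engine: Any = None
    _lock = threading.Lock()

    @classmethod
    def init_engine(cls, pool_size: int, max_overflow: int, **extra: Any) -> None:
        with cls._lock:
            if cls._engine:
                return  # first initialization wins; later calls are no-ops
            cls._engine = {"pool_size": pool_size, "max_overflow": max_overflow, **extra}

    @classmethod
    def get_engine(cls) -> Any:
        if not cls._engine:
            # fail fast rather than silently creating an engine with defaults
            raise RuntimeError("Engine not initialized. Must call init_engine first.")
        return cls._engine


EngineHolder.init_engine(pool_size=20, max_overflow=5)
print(EngineHolder.get_engine())
```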
@@ -99,16 +99,18 @@ def _convert_litellm_message_to_langchain_message(
     elif role == "assistant":
         return AIMessage(
             content=content,
-            tool_calls=[
-                {
-                    "name": tool_call.function.name or "",
-                    "args": json.loads(tool_call.function.arguments),
-                    "id": tool_call.id,
-                }
-                for tool_call in tool_calls
-            ]
-            if tool_calls
-            else [],
+            tool_calls=(
+                [
+                    {
+                        "name": tool_call.function.name or "",
+                        "args": json.loads(tool_call.function.arguments),
+                        "id": tool_call.id,
+                    }
+                    for tool_call in tool_calls
+                ]
+                if tool_calls
+                else []
+            ),
         )
     elif role == "system":
         return SystemMessage(content=content)
@@ -409,6 +411,13 @@ class DefaultMultiLLM(LLM):
         processed_prompt = _prompt_to_dict(prompt)
         self._record_call(processed_prompt)
 
+        NO_TEMPERATURE_MODELS = [
+            "o4-mini",
+            "o3-mini",
+            "o3",
+            "o3-preview",
+        ]
+
         try:
             return litellm.completion(
                 mock_response=MOCK_LLM_RESPONSE,
@@ -428,9 +437,13 @@ class DefaultMultiLLM(LLM):
                 # streaming choice
                 stream=stream,
                 # model params
-                temperature=0,
                 timeout=timeout_override or self._timeout,
                 max_tokens=max_tokens,
+                **(
+                    {"temperature": self._temperature}
+                    if self.config.model_name not in NO_TEMPERATURE_MODELS
+                    else {}
+                ),
                 # For now, we don't support parallel tool calls
                 # NOTE: we can't pass this in if tools are not specified
                 # or else OpenAI throws an error
@@ -439,6 +452,7 @@ class DefaultMultiLLM(LLM):
                 if tools
                 and self.config.model_name
                 not in [
+                    "o4-mini",
                     "o3-mini",
                     "o3-preview",
                     "o1",
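The temperature change leans on conditional keyword unpacking: `**({...} if cond else {})` contributes the keyword argument only when the dict is non-empty, which matters for APIs that reject a parameter even when it is explicitly set (here, OpenAI reasoning models rejecting `temperature`). The trick in isolation:

```python
def completion(model: str, **kwargs: object) -> dict[str, object]:
    # stand-in for litellm.completion; just echoes what it received
    return {"model": model, **kwargs}


NO_TEMPERATURE_MODELS = ["o3-mini", "o4-mini"]

for model in ["gpt-4o", "o3-mini"]:
    resp = completion(
        model=model,
        # unpacking an empty dict adds nothing at all to the call
        **({"temperature": 0.2} if model not in NO_TEMPERATURE_MODELS else {}),
    )
    print(resp)  # gpt-4o gets a temperature key; o3-mini does not
```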
@@ -27,10 +27,13 @@ class WellKnownLLMProviderDescriptor(BaseModel):
 
 OPENAI_PROVIDER_NAME = "openai"
 OPEN_AI_MODEL_NAMES = [
+    "o4-mini",
     "o3-mini",
     "o1-mini",
+    "o3",
     "o1",
     "gpt-4",
+    "gpt-4.1",
     "gpt-4o",
     "gpt-4o-mini",
     "o1-preview",
@@ -19,6 +19,7 @@ from httpx_oauth.clients.google import GoogleOAuth2
 from sentry_sdk.integrations.fastapi import FastApiIntegration
 from sentry_sdk.integrations.starlette import StarletteIntegration
 from sqlalchemy.orm import Session
+from starlette.types import Lifespan
 
 from onyx import __version__
 from onyx.auth.schemas import UserCreate
@@ -264,8 +265,12 @@ def log_http_error(request: Request, exc: Exception) -> JSONResponse:
     )
 
 
-def get_application() -> FastAPI:
-    application = FastAPI(title="Onyx Backend", version=__version__, lifespan=lifespan)
+def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
+    application = FastAPI(
+        title="Onyx Backend",
+        version=__version__,
+        lifespan=lifespan_override or lifespan,
+    )
     if SENTRY_DSN:
         sentry_sdk.init(
             dsn=SENTRY_DSN,
@@ -39,9 +39,9 @@ from onyx.context.search.retrieval.search_runner import (
 from onyx.db.engine import get_all_tenant_ids
 from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import SqlEngine
 from onyx.db.models import SlackBot
 from onyx.db.search_settings import get_current_search_settings
-from onyx.db.slack_bot import fetch_slack_bot
 from onyx.db.slack_bot import fetch_slack_bots
 from onyx.key_value_store.interface import KvKeyNotFoundError
 from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -520,25 +520,6 @@ class SlackbotHandler:
 
 def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
     """True to keep going, False to ignore this Slack request"""
-
-    # skip cases where the bot is disabled in the web UI
-    bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
-    with get_session_with_current_tenant() as db_session:
-        slack_bot = fetch_slack_bot(
-            db_session=db_session, slack_bot_id=client.slack_bot_id
-        )
-        if not slack_bot:
-            logger.error(
-                f"Slack bot with ID '{client.slack_bot_id}' not found. Skipping request."
-            )
-            return False
-
-        if not slack_bot.enabled:
-            logger.info(
-                f"Slack bot with ID '{client.slack_bot_id}' is disabled. Skipping request."
-            )
-            return False
-
     if req.type == "events_api":
         # Verify channel is valid
         event = cast(dict[str, Any], req.payload.get("event", {}))
@@ -954,6 +935,9 @@ def _get_socket_client(
 
 
 if __name__ == "__main__":
+    # Initialize the SqlEngine
+    SqlEngine.init_engine(pool_size=20, max_overflow=5)
+
     # Initialize the tenant handler which will manage tenant connections
     logger.info("Starting SlackbotHandler")
     tenant_handler = SlackbotHandler()
@@ -324,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
     logger.info(
         "No existing docs or connectors found. Checking GPU availability for multipass indexing."
     )
-    gpu_available = gpu_status_request(indexing=True)
+    gpu_available = gpu_status_request()
     logger.info(f"GPU available: {gpu_available}")
 
     current_settings = get_current_search_settings(db_session)
@@ -1,5 +1,3 @@
-from functools import lru_cache
-
 import requests
 from retry import retry
 
@@ -12,7 +10,8 @@ from shared_configs.configs import MODEL_SERVER_PORT
 logger = setup_logger()
 
 
-def _get_gpu_status_from_model_server(indexing: bool) -> bool:
+@retry(tries=5, delay=5)
+def gpu_status_request(indexing: bool = True) -> bool:
     if indexing:
         model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
     else:
@@ -29,14 +28,3 @@ def _get_gpu_status_from_model_server(indexing: bool) -> bool:
     except requests.RequestException as e:
         logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
         raise  # Re-raise exception to trigger a retry
-
-
-@retry(tries=5, delay=5)
-def gpu_status_request(indexing: bool) -> bool:
-    return _get_gpu_status_from_model_server(indexing)
-
-
-@lru_cache(maxsize=1)
-def fast_gpu_status_request(indexing: bool) -> bool:
-    """For use in sync flows, where we don't want to retry / we want to cache this."""
-    return gpu_status_request(indexing=indexing)
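With the `lru_cache`d fast variant gone, `gpu_status_request` is once more the retried entry point itself, so every caller pays up to five attempts on failure (with a five-second delay between them) instead of hitting a cached answer. The `retry` package decorator used here behaves like this (the demo uses a short delay so it runs quickly; the hunk uses `delay=5`):

```python
from retry import retry

calls = {"n": 0}


@retry(tries=5, delay=0.1)
def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")  # triggers a retry
    return "ok"


print(flaky())      # "ok" on the third attempt
print(calls["n"])   # 3 — earlier failures were swallowed and retried
```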
@@ -38,7 +38,7 @@ langchainhub==0.1.21
 langgraph==0.2.72
 langgraph-checkpoint==2.0.13
 langgraph-sdk==0.1.44
-litellm==1.63.8
+litellm==1.66.3
 lxml==5.3.0
 lxml_html_clean==0.2.2
 llama-index==0.9.45
@@ -47,7 +47,7 @@ msal==1.28.0
 nltk==3.8.1
 Office365-REST-Python-Client==2.5.9
 oauthlib==3.2.2
-openai==1.66.3
+openai==1.75.0
 openpyxl==3.1.2
 playwright==1.41.2
 psutil==5.9.5
@@ -1,6 +1,5 @@
 import os
 import time
-from typing import Any
 from unittest.mock import MagicMock
 from unittest.mock import patch
 
@@ -8,16 +7,15 @@ import pytest
 
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.confluence.connector import ConfluenceConnector
-from onyx.connectors.confluence.utils import AttachmentProcessingResult
 from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
 from onyx.connectors.models import Document
 
 
 @pytest.fixture
-def confluence_connector(space: str) -> ConfluenceConnector:
+def confluence_connector() -> ConfluenceConnector:
     connector = ConfluenceConnector(
         wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
-        space=space,
+        space=os.environ["CONFLUENCE_TEST_SPACE"],
         is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
         page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
     )
@@ -34,15 +32,14 @@ def confluence_connector(space: str) -> ConfluenceConnector:
     return connector
 
 
-@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
 @patch(
     "onyx.file_processing.extract_file_text.get_unstructured_api_key",
     return_value=None,
 )
+@pytest.mark.skip(reason="Skipping this test")
 def test_confluence_connector_basic(
     mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
 ) -> None:
-    confluence_connector.set_allow_images(False)
     doc_batch_generator = confluence_connector.poll_source(0, time.time())
 
     doc_batch = next(doc_batch_generator)
@@ -53,14 +50,15 @@ def test_confluence_connector_basic(
 
     page_within_a_page_doc: Document | None = None
     page_doc: Document | None = None
+    txt_doc: Document | None = None
 
     for doc in doc_batch:
         if doc.semantic_identifier == "DailyConnectorTestSpace Home":
             page_doc = doc
+        elif ".txt" in doc.semantic_identifier:
+            txt_doc = doc
         elif doc.semantic_identifier == "Page Within A Page":
             page_within_a_page_doc = doc
-        else:
-            pass
 
     assert page_within_a_page_doc is not None
     assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
@@ -81,7 +79,7 @@ def test_confluence_connector_basic(
     assert page_doc.metadata["labels"] == ["testlabel"]
     assert page_doc.primary_owners
     assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
-    assert len(page_doc.sections) == 2  # page text + attachment text
+    assert len(page_doc.sections) == 1
 
     page_section = page_doc.sections[0]
     assert page_section.text == "test123 " + page_within_a_page_text
@@ -90,65 +88,13 @@ def test_confluence_connector_basic(
         == "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
     )
 
-    text_attachment_section = page_doc.sections[1]
-    assert text_attachment_section.text == "small"
-    assert text_attachment_section.link
-    assert text_attachment_section.link.endswith("small-file.txt")
-
-
-@pytest.mark.parametrize("space", ["MI"])
-@patch(
-    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
-    return_value=None,
-)
-def test_confluence_connector_skip_images(
-    mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
-) -> None:
-    confluence_connector.set_allow_images(False)
-    doc_batch_generator = confluence_connector.poll_source(0, time.time())
-
-    doc_batch = next(doc_batch_generator)
-    with pytest.raises(StopIteration):
-        next(doc_batch_generator)
-
-    assert len(doc_batch) == 8
-    assert sum(len(doc.sections) for doc in doc_batch) == 8
-
-
-def mock_process_image_attachment(
-    *args: Any, **kwargs: Any
-) -> AttachmentProcessingResult:
-    """We need this mock to bypass DB access happening in the connector. Which shouldn't
-    be done as a rule to begin with, but life is not perfect. Fix it later"""
-
-    return AttachmentProcessingResult(
-        text="Hi_text",
-        file_name="Hi_filename",
-        error=None,
-    )
-
-
-@pytest.mark.parametrize("space", ["MI"])
-@patch(
-    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
-    return_value=None,
-)
-@patch(
-    "onyx.connectors.confluence.utils._process_image_attachment",
-    side_effect=mock_process_image_attachment,
-)
-def test_confluence_connector_allow_images(
-    mock_get_api_key: MagicMock,
-    mock_process_image_attachment: MagicMock,
-    confluence_connector: ConfluenceConnector,
-) -> None:
-    confluence_connector.set_allow_images(True)
-    doc_batch_generator = confluence_connector.poll_source(0, time.time())
-
-    doc_batch = next(doc_batch_generator)
-    with pytest.raises(StopIteration):
-        next(doc_batch_generator)
-
-    assert len(doc_batch) == 8
-    assert sum(len(doc.sections) for doc in doc_batch) == 12
+    assert txt_doc is not None
+    assert txt_doc.semantic_identifier == "small-file.txt"
+    assert len(txt_doc.sections) == 1
+    assert txt_doc.sections[0].text == "small"
+    assert txt_doc.primary_owners
+    assert txt_doc.primary_owners[0].email == "chris@onyx.app"
+    assert (
+        txt_doc.sections[0].link
+        == "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
+    )
@@ -4,6 +4,7 @@ import pytest
 
 from onyx.auth.schemas import UserRole
 from onyx.db.engine import get_session_context_manager
+from onyx.db.engine import SqlEngine
 from onyx.db.search_settings import get_current_search_settings
 from tests.integration.common_utils.constants import ADMIN_USER_NAME
 from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -48,6 +49,15 @@ instantiate the session directly within the test.
 # yield session
 
 
+@pytest.fixture(scope="session", autouse=True)
+def initialize_db() -> None:
+    # Make sure that the db engine is initialized before any tests are run
+    SqlEngine.init_engine(
+        pool_size=10,
+        max_overflow=5,
+    )
+
+
 @pytest.fixture
 def vespa_client() -> vespa_fixture:
     with get_session_context_manager() as db_session:
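The new fixture combines `scope="session"` with `autouse=True`, so pytest runs it exactly once, before the first test, without any test having to request it by name. The mechanics in miniature:

```python
import pytest

_initialized: list[str] = []


@pytest.fixture(scope="session", autouse=True)
def initialize_once() -> None:
    # runs exactly once per test session, before the first test
    _initialized.append("engine")


def test_a() -> None:
    assert _initialized == ["engine"]


def test_b() -> None:
    assert _initialized == ["engine"]  # not re-run for the second test
```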
@@ -50,7 +50,7 @@ def answer_instance(
     mocker: MockerFixture,
 ) -> Answer:
     mocker.patch(
-        "onyx.chat.answer.fast_gpu_status_request",
+        "onyx.chat.answer.gpu_status_request",
         return_value=True,
     )
     return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
@@ -400,7 +400,7 @@ def test_no_slow_reranking(
     mocker: MockerFixture,
 ) -> None:
     mocker.patch(
-        "onyx.chat.answer.fast_gpu_status_request",
+        "onyx.chat.answer.gpu_status_request",
         return_value=gpu_enabled,
     )
     rerank_settings = (

@@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
     mocker: MockerFixture,
 ) -> None:
     mocker.patch(
-        "onyx.chat.answer.fast_gpu_status_request",
+        "onyx.chat.answer.gpu_status_request",
         return_value=True,
     )
     question = config["question"]
@@ -63,6 +63,10 @@ services:
       - POSTGRES_HOST=relational_db
       - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
       - POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
+      - POSTGRES_API_SERVER_POOL_SIZE=${POSTGRES_API_SERVER_POOL_SIZE:-}
+      - POSTGRES_API_SERVER_POOL_OVERFLOW=${POSTGRES_API_SERVER_POOL_OVERFLOW:-}
+      - POSTGRES_IDLE_SESSIONS_TIMEOUT=${POSTGRES_IDLE_SESSIONS_TIMEOUT:-}
+      - POSTGRES_POOL_RECYCLE=${POSTGRES_POOL_RECYCLE:-}
      - VESPA_HOST=index
      - REDIS_HOST=cache
      - WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
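The compose change only threads the new `POSTGRES_*` knobs through as environment variables; `${VAR:-}` expands to an empty string when the variable is unset, so the reading side must treat empty as "use the default". A sketch of the corresponding read pattern (the default values below are assumptions for illustration, not taken from the repo):

```python
import os


def _int_env(name: str, default: int) -> int:
    # ${VAR:-} in docker-compose yields "" when unset, so guard against it
    raw = os.environ.get(name, "")
    return int(raw) if raw else default


POSTGRES_API_SERVER_POOL_SIZE = _int_env("POSTGRES_API_SERVER_POOL_SIZE", 40)
POSTGRES_API_SERVER_POOL_OVERFLOW = _int_env("POSTGRES_API_SERVER_POOL_OVERFLOW", 10)
POSTGRES_POOL_RECYCLE = _int_env("POSTGRES_POOL_RECYCLE", 3600)
```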
@@ -985,11 +985,6 @@ export function AssistantEditor({
                     )
                   : null
               }
-              requiresImageGeneration={
-                imageGenerationTool
-                  ? values.enabled_tools_map[imageGenerationTool.id]
-                  : false
-              }
              onSelect={(selected) => {
                if (selected === null) {
                  setFieldValue("llm_model_version_override", null);
@@ -6,9 +6,11 @@ import { ErrorCallout } from "@/components/ErrorCallout";
 import { ThreeDotsLoader } from "@/components/Loading";
 import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
 import { usePopup } from "@/components/admin/connectors/Popup";
+import Link from "next/link";
 import { SlackChannelConfigsTable } from "./SlackChannelConfigsTable";
 import { useSlackBot, useSlackChannelConfigsByBot } from "./hooks";
 import { ExistingSlackBotForm } from "../SlackBotUpdateForm";
 import { FiPlusSquare } from "react-icons/fi";
+import { Separator } from "@/components/ui/separator";
 
 function SlackBotEditPage({
@@ -35,11 +37,7 @@ function SlackBotEditPage({
   } = useSlackChannelConfigsByBot(Number(unwrappedParams["bot-id"]));
 
   if (isSlackBotLoading || isSlackChannelConfigsLoading) {
-    return (
-      <div className="flex justify-center items-center h-screen">
-        <ThreeDotsLoader />
-      </div>
-    );
+    return <ThreeDotsLoader />;
   }
 
   if (slackBotError || !slackBot) {
@@ -69,7 +67,7 @@ function SlackBotEditPage({
   }
 
   return (
-    <>
+    <div className="container mx-auto">
       <InstantSSRAutoRefresh />
 
       <BackButton routerOverride="/admin/bots" />
@@ -88,18 +86,8 @@ function SlackBotEditPage({
         setPopup={setPopup}
       />
     </div>
-  </>
+  </div>
   );
 }
 
-export default function Page({
-  params,
-}: {
-  params: Promise<{ "bot-id": string }>;
-}) {
-  return (
-    <div className="container mx-auto">
-      <SlackBotEditPage params={params} />
-    </div>
-  );
-}
+export default SlackBotEditPage;
@@ -1383,7 +1383,6 @@ export function ChatPage({
           if (!packet) {
             continue;
           }
-          console.log("Packet:", JSON.stringify(packet));
 
           if (!initialFetchDetails) {
             if (!Object.hasOwn(packet, "user_message_id")) {
@@ -1729,7 +1728,6 @@ export function ChatPage({
         }
       }
     } catch (e: any) {
-      console.log("Error:", e);
       const errorMsg = e.message;
       upsertToCompleteMessageMap({
         messages: [
@@ -1757,13 +1755,11 @@ export function ChatPage({
         completeMessageMapOverride: currentMessageMap(completeMessageDetail),
       });
     }
-    console.log("Finished streaming");
     setAgenticGenerating(false);
     resetRegenerationState(currentSessionId());
 
     updateChatState("input");
     if (isNewSession) {
-      console.log("Setting up new session");
       if (finalMessage) {
         setSelectedMessageForDocDisplay(finalMessage.message_id);
       }
@@ -2,7 +2,6 @@ import { redirect } from "next/navigation";
 import { unstable_noStore as noStore } from "next/cache";
 import { fetchChatData } from "@/lib/chat/fetchChatData";
 import { ChatProvider } from "@/components/context/ChatContext";
-import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
 
 export default async function Layout({
   children,
@@ -41,7 +40,6 @@ export default async function Layout({
 
   return (
     <>
-      <InstantSSRAutoRefresh />
       <ChatProvider
         value={{
           proSearchToggled,
@@ -365,7 +365,6 @@ export function UserSettingsModal({
                     )
                   : null
               }
-              requiresImageGeneration={false}
               onSelect={(selected) => {
                 if (selected === null) {
                   handleChangedefaultModel(null);
@@ -22,7 +22,6 @@ interface LLMSelectorProps {
   llmProviders: LLMProviderDescriptor[];
   currentLlm: string | null;
   onSelect: (value: string | null) => void;
-  requiresImageGeneration?: boolean;
 }
 
 export const LLMSelector: React.FC<LLMSelectorProps> = ({
@@ -30,7 +29,6 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
   llmProviders,
   currentLlm,
   onSelect,
-  requiresImageGeneration,
 }) => {
   const seenModelNames = new Set();
@@ -90,19 +88,14 @@ export const LLMSelector: React.FC<LLMSelectorProps> = ({
         )}
       </SelectItem>
       {llmOptions.map((option) => {
-        if (
-          !requiresImageGeneration ||
-          checkLLMSupportsImageInput(option.name)
-        ) {
-          return (
-            <SelectItem key={option.value} value={option.value}>
-              <div className="my-1 flex items-center">
-                {option.icon && option.icon({ size: 16 })}
-                <span className="ml-2">{option.name}</span>
-              </div>
-            </SelectItem>
-          );
-        }
+        return (
+          <SelectItem key={option.value} value={option.value}>
+            <div className="my-1 flex items-center">
+              {option.icon && option.icon({ size: 16 })}
+              <span className="ml-2">{option.name}</span>
+            </div>
+          </SelectItem>
+        );
       })}
     </SelectContent>
   </Select>