Compare commits

...

19 Commits

Author SHA1 Message Date
Nik
bb03160e0b fix(metrics): widen Redis exception handler, fix docstring callsite
- Widen outer Redis except to (ImportError, RuntimeError, AttributeError)
  matching the Postgres section for consistency
- Fix set_start_time docstring to reference start_observability()
2026-03-12 13:19:07 -07:00
Nik
dc1da34f95 fix(metrics): address review feedback on admin debug endpoints
- Use lazy PID-aware _get_process() instead of module-level Process()
- Prime cpu_percent() during set_start_time() so first call is accurate
- Return None instead of 0 for uptime_seconds when start_time not set
- Add comments documenting RedisPool singleton behavior and redis-py
  private attribute compatibility
2026-03-12 13:18:08 -07:00
Nik
e730e571ae feat(metrics): add opt-in admin debug endpoints
Add flag-gated admin debug JSON endpoints for live pod inspection.
Requires ENABLE_ADMIN_DEBUG_ENDPOINTS=true and admin auth.

Endpoints:
- /admin/debug/process-info — RSS, VMS, CPU%, threads, uptime
- /admin/debug/pool-state — Postgres + Redis pool state
- /admin/debug/threads — all threads via threading.enumerate()
- /admin/debug/event-loop-lag — current + max lag

Wires admin debug router through setup_app_observability() with
include_router_fn callback for global prefix support.
2026-03-12 13:18:08 -07:00
Nik
e1af8a68f3 fix(metrics): fix collector registration, docstring, and log message
- Move _collector assignment inside try block so it's only set on
  successful registration (prevents silent loss of metrics on reload)
- Document event loop requirement in docstring
- Name asyncio task for debuggability
- Fix misleading "collector registered" log to "started"
2026-03-12 13:17:56 -07:00
Nik
998dbb7f97 fix(metrics): address review feedback on deep profiling
- Clear _previous_snapshot on stop so restart computes fresh baseline
- Check task.done() in idempotency guard for failed task restart
- Handle REGISTRY duplicate registration on hot-reload
- Add explanatory comment on empty describe()
2026-03-12 13:16:25 -07:00
Nik
d8f96e134b feat(metrics): add opt-in deep profiling (tracemalloc + GC + object counting)
Add flag-gated deep profiling that automatically exports memory
allocation sites, GC stats, and object type counts to Prometheus.

Requires ENABLE_DEEP_PROFILING=true (~10-20% allocation overhead).

New metrics: tracemalloc top bytes/count/delta, total bytes,
GC collections/collected/uncollectable per generation, and
object type counts (top N types).
2026-03-12 13:16:25 -07:00
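The snapshot/diff technique this commit builds on can be seen in isolation with stdlib-only code. This is an illustrative sketch, not code from the branch; the variable names are made up here:

```python
import tracemalloc

tracemalloc.start(10)  # 10-frame tracebacks, matching the commit's depth

baseline = tracemalloc.take_snapshot()
leak = [bytearray(1024) for _ in range(1000)]  # roughly 1 MiB of new allocations
snapshot = tracemalloc.take_snapshot()

# compare_to() returns StatisticDiff objects sorted by largest change first;
# the production loop keeps only the top N and exports them as gauge labels.
diffs = snapshot.compare_to(baseline, "lineno")
total_growth = sum(d.size_diff for d in diffs)
tracemalloc.stop()
```

The bytearray line dominates `diffs[0]`, which is exactly the "top allocation site" signal the new metrics expose.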
Nik
270966561e fix(metrics): hoist globals, document benign race in _get_process
Move _process/_process_pid declarations above the function. Add
docstring note explaining why the race is benign (worst case is
a harmless double-create of psutil.Process for the same PID).
2026-03-12 13:16:12 -07:00
Nik
bbebf11577 fix(metrics): address review feedback on thread pool instrumentation
- Move _TASKS_SUBMITTED.inc() after super().submit() to avoid
  counting tasks that fail to submit
- Add explanatory comment on empty describe()
- Use lazy PID-aware _get_process() (same pattern as memory_delta)
- Handle hot-reload with try/except on REGISTRY.register()
- Fix "Always-on" wording in METRICS.md (it's conditional on usage)
2026-03-12 13:13:39 -07:00
Nik
468b8a5f4f feat(metrics): add thread pool instrumentation
Add InstrumentedThreadPoolExecutor that wraps submit() to track task
submission, active count, and duration. Also adds a custom Collector
for process-wide thread count via psutil.

New metrics:
- onyx_threadpool_tasks_submitted_total (counter)
- onyx_threadpool_tasks_active (gauge)
- onyx_threadpool_task_duration_seconds (histogram)
- onyx_process_thread_count (gauge)

Modifies threadpool_concurrency.py to use the instrumented executor.
2026-03-12 13:13:39 -07:00
Nik
d3ae032110 fix(metrics): reset module state on probe restart, name asyncio task
- Reset _current_lag/_max_lag alongside gauge reset in
  start_event_loop_lag_probe() so a stale max cannot suppress updates
- Add name="onyx-event-loop-lag-probe" to asyncio.create_task()
  for debuggability
2026-03-12 13:13:23 -07:00
Nik
a439c3d74c fix(metrics): address review feedback on event loop lag probe
- Wrap sleep/time inside try block so probe survives transient errors
- Re-raise CancelledError so stop_event_loop_lag_probe() works
- Check task.done() in idempotency guard so failed tasks restart
- Initialize gauges to 0 at startup for first-scrape visibility
- Save/restore _current_lag in tests for isolation
- Remove duplicate Redis Pool section from METRICS.md (rebase artifact)
2026-03-12 13:11:42 -07:00
Nik
49310f9d58 feat(metrics): add event loop lag probe
Add always-on asyncio background task that measures scheduling lag
by comparing expected vs actual wakeup time.

New metrics:
- onyx_api_event_loop_lag_seconds (gauge, current lag)
- onyx_api_event_loop_lag_max_seconds (gauge, max observed)

Configurable via EVENT_LOOP_LAG_PROBE_INTERVAL_SECONDS (default 2s).

Add stop_observability() to prometheus_setup.py for async shutdown.
2026-03-12 13:11:42 -07:00
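The measurement idea (compare expected vs. actual wakeup time) can be demonstrated standalone. This sketch is illustrative only and is not the probe's actual code:

```python
import asyncio
import time


async def measure_lag_once(interval: float) -> float:
    """One probe iteration: schedule a sleep, compare expected vs actual wakeup."""
    loop = asyncio.get_running_loop()
    before = loop.time()
    await asyncio.sleep(interval)
    lag = (loop.time() - before) - interval
    return max(lag, 0.0)


async def main() -> tuple[float, float]:
    # Baseline: on an idle loop the probe wakes close to on time.
    idle_lag = await measure_lag_once(0.05)
    # Now block the loop with synchronous work while a probe is sleeping.
    probe = asyncio.ensure_future(measure_lag_once(0.05))
    await asyncio.sleep(0)  # let the probe reach its await
    time.sleep(0.2)  # synchronous block: the loop cannot wake the probe
    blocked_lag = await probe
    return idle_lag, blocked_lag


idle_lag, blocked_lag = asyncio.run(main())
```

Synchronous blocking inflates the second measurement, which is the spike the gauge is meant to surface.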
Nik
88b8b79247 feat(metrics): add Redis connection pool metrics
Add always-on Redis pool collector that reports pool state on each
Prometheus scrape via BlockingConnectionPool internals.

New metrics:
- onyx_redis_pool_in_use (gauge by pool)
- onyx_redis_pool_available (gauge by pool)
- onyx_redis_pool_max (gauge by pool)
- onyx_redis_pool_created (gauge by pool)

Add start_observability() to prometheus_setup.py for lifespan-scoped
collector registration.
2026-03-12 13:11:42 -07:00
Nik
754957d696 fix(metrics): handle hot-reload idempotency for Redis pool collector
Catch ValueError from duplicate REGISTRY.register() on Uvicorn
hot-reload (module reimport resets guard but REGISTRY persists).
Add explanatory comment on empty describe().
2026-03-12 13:11:36 -07:00
Nik
7696a68327 feat(metrics): add Redis connection pool metrics
Add always-on Redis pool collector that reports pool state on each
Prometheus scrape via BlockingConnectionPool internals.

New metrics:
- onyx_redis_pool_in_use (gauge by pool)
- onyx_redis_pool_available (gauge by pool)
- onyx_redis_pool_max (gauge by pool)
- onyx_redis_pool_created (gauge by pool)

Add start_observability() to prometheus_setup.py for lifespan-scoped
collector registration.
2026-03-12 13:11:36 -07:00
Nik
0be29c8716 fix(metrics): hoist globals before _get_process, trim METRICS.md table
- Move _process/_process_pid declarations above the function that
  uses them for readability
- Remove start/stop_observability from orchestration table (added in
  later PRs in this stack)
2026-03-12 13:11:24 -07:00
Nik
50330ce852 fix(metrics): address review feedback on memory delta PR
- Use lazy PID-aware psutil.Process via _get_process() to handle
  Uvicorn forked workers correctly (was capturing parent PID)
- Fix prometheus_setup.py docstring to only reference functions
  that exist on this branch
- Add test for psutil error early-return path in middleware
2026-03-09 15:32:26 -07:00
Nik
3a9500e970 fix(metrics): use abs(delta) histogram + shrink counter for Prometheus compat
Negative histogram buckets break histogram_quantile(). Split into:
- abs(delta) in histogram (all positive buckets)
- separate counter for requests where RSS decreased
2026-03-09 11:33:42 -07:00
Nik
09350580a3 feat(metrics): add per-endpoint memory delta tracking
Add always-on RSS delta measurement per API request via psutil.
Two new metrics:
- onyx_api_request_rss_delta_bytes (histogram by handler)
- onyx_api_process_rss_bytes (gauge, current process RSS)

Refactor prometheus_setup.py into a central orchestration module with
setup_app_observability() for app-scoped middleware and future
start/stop_observability() hooks for lifespan-scoped probes.

Update METRICS.md wiring docs to reflect the new orchestration pattern.
2026-03-04 23:05:48 -08:00
16 changed files with 1731 additions and 14 deletions

View File

@@ -357,6 +357,23 @@ POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW = int(
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"

# --- Observability ---
EVENT_LOOP_LAG_PROBE_INTERVAL_SECONDS = float(
    os.environ.get("EVENT_LOOP_LAG_PROBE_INTERVAL_SECONDS", "2.0")
)
ENABLE_DEEP_PROFILING = os.environ.get("ENABLE_DEEP_PROFILING", "").lower() == "true"
ENABLE_ADMIN_DEBUG_ENDPOINTS = (
    os.environ.get("ENABLE_ADMIN_DEBUG_ENDPOINTS", "").lower() == "true"
)
DEEP_PROFILING_SNAPSHOT_INTERVAL_SECONDS = float(
    os.environ.get("DEEP_PROFILING_SNAPSHOT_INTERVAL_SECONDS", "60.0")
)
DEEP_PROFILING_TOP_N_ALLOCATIONS = int(
    os.environ.get("DEEP_PROFILING_TOP_N_ALLOCATIONS", "20")
)
DEEP_PROFILING_TOP_N_TYPES = int(os.environ.get("DEEP_PROFILING_TOP_N_TYPES", "30"))

# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"

View File

@@ -125,7 +125,10 @@ from onyx.server.manage.web_search.api import (
from onyx.server.metrics.postgres_connection_pool import (
    setup_postgres_connection_pool_metrics,
)
from onyx.server.metrics.prometheus_setup import setup_app_observability
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
from onyx.server.metrics.prometheus_setup import start_observability
from onyx.server.metrics.prometheus_setup import stop_observability
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
@@ -334,6 +337,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:  # noqa: ARG001
        },
    )

    # Lifespan-scoped observability (redis pool, etc.).
    # All probes/collectors are orchestrated through prometheus_setup.
    start_observability()

    verify_auth = fetch_versioned_implementation(
        "onyx.auth.users", "verify_auth_setting"
    )
@@ -387,6 +394,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:  # noqa: ARG001
    stop_periodic_poller()

    await stop_observability()
    SqlEngine.reset_engine()
    if AUTH_RATE_LIMITING_ENABLED:
@@ -640,6 +649,14 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # App-scoped observability (admin debug router, memory delta middleware).
    # Must be called after all routers — memory delta builds its route map
    # at registration time.
    setup_app_observability(
        application,
        include_router_fn=include_router_with_global_prefix_prepended,
    )
    if LOG_ENDPOINT_LATENCY:
        add_latency_logging_middleware(application, logger)

View File

@@ -0,0 +1,173 @@
"""Admin debug endpoints for live pod inspection.
Provides JSON endpoints for process info, pool state, threads,
and event loop lag. Only included when ENABLE_ADMIN_DEBUG_ENDPOINTS=true.
Requires admin authentication.
"""
import os
import threading
import time
from typing import Any
from typing import cast
import psutil
from fastapi import APIRouter
from fastapi import Depends
from onyx.auth.users import current_admin_user
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(
prefix="/admin/debug",
tags=["debug"],
dependencies=[Depends(current_admin_user)],
)
_start_time: float | None = None
def _get_process() -> psutil.Process:
"""Return a psutil.Process for the *current* PID.
Lazily created and invalidated when PID changes (fork).
"""
global _process, _process_pid
pid = os.getpid()
if _process is None or _process_pid != pid:
_process = psutil.Process(pid)
# Prime cpu_percent() so the first real call returns a
# meaningful value instead of 0.0.
_process.cpu_percent()
_process_pid = pid
return _process
_process: psutil.Process | None = None
_process_pid: int | None = None
def set_start_time() -> None:
"""Capture server startup time. Called from start_observability()."""
global _start_time
if _start_time is None:
_start_time = time.monotonic()
# Warm the process handle so cpu_percent() is primed.
_get_process()
@router.get("/process-info")
def get_process_info() -> dict[str, Any]:
"""Return process-level resource info."""
proc = _get_process()
mem = proc.memory_info()
uptime: float | None = (
round(time.monotonic() - _start_time, 1) if _start_time is not None else None
)
info: dict[str, Any] = {
"rss_bytes": mem.rss,
"vms_bytes": mem.vms,
"cpu_percent": proc.cpu_percent(),
"num_threads": proc.num_threads(),
"uptime_seconds": uptime,
}
# num_fds() is Linux-only; skip gracefully on macOS/Windows
try:
info["num_fds"] = proc.num_fds()
except (AttributeError, psutil.Error):
pass
return info
@router.get("/pool-state")
def get_pool_state() -> dict[str, Any]:
"""Return Postgres + Redis pool state as JSON."""
result: dict[str, Any] = {"postgres": {}, "redis": {}}
# Postgres pools
try:
from onyx.db.engine.sql_engine import SqlEngine
from sqlalchemy.pool import QueuePool
for label, engine in [
("sync", SqlEngine.get_engine()),
("readonly", SqlEngine.get_readonly_engine()),
]:
pool = engine.pool
if isinstance(pool, QueuePool):
result["postgres"][label] = {
"checked_out": pool.checkedout(),
"checked_in": pool.checkedin(),
"overflow": pool.overflow(),
"size": pool.size(),
}
except (ImportError, RuntimeError, AttributeError):
logger.warning("Failed to read postgres pool state", exc_info=True)
result["postgres"]["error"] = "unable to read pool state"
# Redis pools — uses private redis-py attributes (_in_use_connections, etc.)
# because there is no public API for pool statistics. Wrapped per-pool so
# one failure doesn't block the other.
# NOTE: RedisPool is a singleton — RedisPool() returns the existing instance.
# NOTE: _in_use_connections, _available_connections, _created_connections are
# private attrs on BlockingConnectionPool. If redis-py changes these in a
# future version, the per-pool except block catches AttributeError gracefully.
try:
from redis import BlockingConnectionPool
from onyx.redis.redis_pool import RedisPool
pool_instance = RedisPool()
# Replica pool always exists (defaults to same host as primary)
for label, rpool in [
("primary", cast(BlockingConnectionPool, pool_instance._pool)),
("replica", cast(BlockingConnectionPool, pool_instance._replica_pool)),
]:
try:
result["redis"][label] = {
"in_use": len(rpool._in_use_connections),
"available": len(rpool._available_connections),
"max_connections": rpool.max_connections,
"created_connections": rpool._created_connections,
}
except (AttributeError, TypeError):
logger.warning(
"Redis pool %s: unable to read internals — "
"redis-py private API may have changed",
label,
exc_info=True,
)
result["redis"][label] = {"error": "unable to read pool internals"}
except (ImportError, RuntimeError, AttributeError):
logger.warning("Failed to read redis pool state", exc_info=True)
result["redis"]["error"] = "unable to read pool state"
return result
@router.get("/threads")
def get_threads() -> list[dict[str, Any]]:
"""Return all threads via threading.enumerate()."""
return [
{
"name": t.name,
"daemon": t.daemon,
"ident": t.ident,
"alive": t.is_alive(),
}
for t in threading.enumerate()
]
@router.get("/event-loop-lag")
def get_event_loop_lag() -> dict[str, float]:
"""Return current and max event loop lag."""
from onyx.server.metrics.event_loop_lag import get_current_lag
from onyx.server.metrics.event_loop_lag import get_max_lag
return {
"current_lag_seconds": get_current_lag(),
"max_lag_seconds": get_max_lag(),
}

View File

@@ -0,0 +1,256 @@
"""Automated deep profiling via tracemalloc, GC stats, and object counting.
When ENABLE_DEEP_PROFILING is true, this module:
1. Starts tracemalloc with 10-frame depth
2. Periodically snapshots allocations and computes diffs
3. Exports top allocation sites, GC stats, and object type counts to Prometheus
All data flows to /metrics automatically — no manual endpoints needed.
Metrics:
- onyx_tracemalloc_top_bytes: Bytes by top source locations
- onyx_tracemalloc_top_count: Allocation count by top source locations
- onyx_tracemalloc_delta_bytes: Growth since previous snapshot
- onyx_tracemalloc_total_bytes: Total traced memory
- onyx_gc_collections_total: GC collections per generation
- onyx_gc_collected_total: Objects collected per generation
- onyx_gc_uncollectable_total: Uncollectable objects per generation
- onyx_object_type_count: Live object count by type
"""
import asyncio
import gc
import os
import tracemalloc
from collections import Counter
from typing import Any
from prometheus_client.core import CounterMetricFamily
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from prometheus_client.registry import REGISTRY
from onyx.configs.app_configs import DEEP_PROFILING_SNAPSHOT_INTERVAL_SECONDS
from onyx.configs.app_configs import DEEP_PROFILING_TOP_N_ALLOCATIONS
from onyx.configs.app_configs import DEEP_PROFILING_TOP_N_TYPES
from onyx.utils.logger import setup_logger
logger = setup_logger()
_snapshot_task: asyncio.Task[None] | None = None
# Mutable state updated by the periodic snapshot task, read by the collector
_current_top_stats: list[tracemalloc.Statistic] = []
_current_delta_stats: list[tracemalloc.StatisticDiff] = []
_current_total_bytes: int = 0
_current_object_type_counts: list[tuple[str, int]] = []
_previous_snapshot: tracemalloc.Snapshot | None = None
_cwd: str = os.getcwd()
def _strip_path(filename: str) -> str:
"""Convert absolute paths to relative for low-cardinality labels."""
# Strip site-packages prefix
for marker in ("site-packages/", "dist-packages/"):
idx = filename.find(marker)
if idx != -1:
return filename[idx + len(marker) :]
# Strip cwd
if filename.startswith(_cwd):
return filename[len(_cwd) :].lstrip("/")
return filename
async def _snapshot_loop(interval: float) -> None:
"""Periodically take tracemalloc snapshots and compute diffs."""
global _previous_snapshot, _current_top_stats, _current_delta_stats
global _current_total_bytes, _current_object_type_counts
while True:
await asyncio.sleep(interval)
try:
if not tracemalloc.is_tracing():
continue
snapshot = tracemalloc.take_snapshot()
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
tracemalloc.Filter(False, tracemalloc.__file__),
)
)
all_stats = snapshot.statistics("lineno")
_current_top_stats = all_stats[:DEEP_PROFILING_TOP_N_ALLOCATIONS]
if _previous_snapshot is not None:
_current_delta_stats = snapshot.compare_to(
_previous_snapshot, "lineno"
)[:DEEP_PROFILING_TOP_N_ALLOCATIONS]
else:
_current_delta_stats = []
_current_total_bytes = sum(stat.size for stat in all_stats)
_previous_snapshot = snapshot
# Object type counting — done here (amortized by snapshot interval)
# instead of on every /metrics scrape, since gc.get_objects() is O(n)
# over all live objects and holds the GIL.
counts: Counter[str] = Counter()
for obj in gc.get_objects():
counts[type(obj).__name__] += 1
_current_object_type_counts = counts.most_common(DEEP_PROFILING_TOP_N_TYPES)
except Exception:
logger.warning(
"Error in deep profiling snapshot loop, skipping iteration",
exc_info=True,
)
class DeepProfilingCollector(Collector):
"""Exports tracemalloc, GC, and object type metrics on each scrape."""
def collect(self) -> list[Any]:
families: list[Any] = []
# --- tracemalloc allocation sites ---
top_bytes = GaugeMetricFamily(
"onyx_tracemalloc_top_bytes",
"Bytes allocated by top source locations",
labels=["source"],
)
top_count = GaugeMetricFamily(
"onyx_tracemalloc_top_count",
"Allocation count by top source locations",
labels=["source"],
)
for stat in _current_top_stats:
source = (
f"{_strip_path(stat.traceback[0].filename)}:{stat.traceback[0].lineno}"
)
top_bytes.add_metric([source], stat.size)
top_count.add_metric([source], stat.count)
families.extend([top_bytes, top_count])
# --- tracemalloc deltas ---
delta_bytes = GaugeMetricFamily(
"onyx_tracemalloc_delta_bytes",
"Allocation growth since previous snapshot",
labels=["source"],
)
for diff_stat in _current_delta_stats:
if diff_stat.size_diff > 0:
source = f"{_strip_path(diff_stat.traceback[0].filename)}:{diff_stat.traceback[0].lineno}"
delta_bytes.add_metric([source], diff_stat.size_diff)
families.append(delta_bytes)
# --- tracemalloc total ---
total = GaugeMetricFamily(
"onyx_tracemalloc_total_bytes",
"Total traced memory in bytes",
)
total.add_metric([], _current_total_bytes)
families.append(total)
# --- GC stats ---
gc_collections = CounterMetricFamily(
"onyx_gc_collections_total",
"GC collections per generation",
labels=["generation"],
)
gc_collected = CounterMetricFamily(
"onyx_gc_collected_total",
"Objects collected per generation",
labels=["generation"],
)
gc_uncollectable = CounterMetricFamily(
"onyx_gc_uncollectable_total",
"Uncollectable objects per generation",
labels=["generation"],
)
for i, stats in enumerate(gc.get_stats()):
gen = str(i)
gc_collections.add_metric([gen], stats["collections"])
gc_collected.add_metric([gen], stats["collected"])
gc_uncollectable.add_metric([gen], stats["uncollectable"])
families.extend([gc_collections, gc_collected, gc_uncollectable])
# --- Object type counts (cached from snapshot loop) ---
type_count = GaugeMetricFamily(
"onyx_object_type_count",
"Live object count by type",
labels=["type"],
)
for type_name, count in _current_object_type_counts:
type_count.add_metric([type_name], count)
families.append(type_count)
return families
def describe(self) -> list[Any]:
# Return empty to mark this as an "unchecked" collector.
# Prometheus checks describe() vs collect() for consistency;
# returning empty opts out since our metrics are dynamic.
return []
_collector: DeepProfilingCollector | None = None
def start_deep_profiling() -> None:
"""Start tracemalloc and the periodic snapshot task.
Idempotent — safe to call multiple times (e.g. Uvicorn hot-reload).
Must be called from within a running asyncio event loop (uses
``asyncio.create_task``). In production this is called from
``start_observability()`` inside FastAPI's async lifespan.
"""
global _snapshot_task, _collector
if _snapshot_task is not None and not _snapshot_task.done():
return
if not tracemalloc.is_tracing():
tracemalloc.start(10)
logger.info("tracemalloc started with 10-frame depth")
else:
logger.info("tracemalloc already active, reusing existing session")
_snapshot_task = asyncio.create_task(
_snapshot_loop(DEEP_PROFILING_SNAPSHOT_INTERVAL_SECONDS),
name="onyx-deep-profiling-snapshot",
)
if _collector is None:
collector = DeepProfilingCollector()
try:
REGISTRY.register(collector)
_collector = collector
except ValueError:
logger.debug("Deep profiling collector already registered, skipping")
logger.info("Deep profiling started")
async def stop_deep_profiling() -> None:
"""Stop tracemalloc and cancel the snapshot task."""
global _snapshot_task, _previous_snapshot
if _snapshot_task is not None:
_snapshot_task.cancel()
try:
await _snapshot_task
except asyncio.CancelledError:
pass
_snapshot_task = None
# Clear stale snapshot so a restart computes a fresh baseline
# instead of diffing against data from before the stop.
_previous_snapshot = None
if tracemalloc.is_tracing():
tracemalloc.stop()
logger.info("tracemalloc stopped")

View File

@@ -0,0 +1,106 @@
"""Event loop lag probe.
Schedules a periodic asyncio task that measures the delta between
expected and actual wakeup time. If the event loop is blocked by
synchronous code or CPU-bound work, the lag will spike.
Metrics:
- onyx_api_event_loop_lag_seconds: Current measured lag
- onyx_api_event_loop_lag_max_seconds: Max observed lag since start
"""
import asyncio
from prometheus_client import Gauge
from onyx.configs.app_configs import EVENT_LOOP_LAG_PROBE_INTERVAL_SECONDS
from onyx.utils.logger import setup_logger
logger = setup_logger()
_LAG: Gauge = Gauge(
"onyx_api_event_loop_lag_seconds",
"Event loop scheduling lag in seconds",
)
_LAG_MAX: Gauge = Gauge(
"onyx_api_event_loop_lag_max_seconds",
"Maximum event loop scheduling lag observed since process start",
)
_probe_task: asyncio.Task[None] | None = None
_current_lag: float = 0.0
_max_lag: float = 0.0
async def _probe_loop(interval: float) -> None:
global _current_lag, _max_lag
loop = asyncio.get_running_loop()
while True:
try:
before = loop.time()
await asyncio.sleep(interval)
after = loop.time()
lag = (after - before) - interval
if lag < 0:
lag = 0.0
_current_lag = lag
_LAG.set(lag)
if lag > _max_lag:
_max_lag = lag
_LAG_MAX.set(_max_lag)
except asyncio.CancelledError:
raise
except Exception:
logger.warning(
"Error in event loop lag probe, skipping iteration",
exc_info=True,
)
def get_current_lag() -> float:
"""Return the last measured lag value."""
return _current_lag
def get_max_lag() -> float:
"""Return the max observed lag since process start."""
return _max_lag
def start_event_loop_lag_probe() -> None:
"""Start the background lag measurement task.
Idempotent — restarts the probe if the previous task finished
or failed (e.g. after an unhandled exception).
"""
global _probe_task, _current_lag, _max_lag
if _probe_task is not None and not _probe_task.done():
return
# Reset module state and gauges so a restart after failure
# computes a fresh baseline (not stale values from the old probe).
_current_lag = 0.0
_max_lag = 0.0
_LAG.set(0.0)
_LAG_MAX.set(0.0)
_probe_task = asyncio.create_task(
_probe_loop(EVENT_LOOP_LAG_PROBE_INTERVAL_SECONDS),
name="onyx-event-loop-lag-probe",
)
async def stop_event_loop_lag_probe() -> None:
"""Cancel the background lag measurement task and await cleanup."""
global _probe_task
if _probe_task is not None:
_probe_task.cancel()
try:
await _probe_task
except asyncio.CancelledError:
pass
_probe_task = None

View File

@@ -0,0 +1,131 @@
"""Per-endpoint memory delta middleware.
Measures RSS change before and after each HTTP request, attributing
memory growth to specific route handlers. Uses psutil for a single
syscall per request (sub-microsecond overhead).
Note: RSS is process-wide, so on a server handling concurrent requests
the delta for one request may include allocations from other requests.
This is inherent to the approach — the metric is most useful for
identifying endpoints that *consistently* cause large deltas.
Metrics:
- onyx_api_request_rss_delta_bytes: Histogram of abs(RSS change) per request
- onyx_api_request_rss_shrink_total: Counter of requests where RSS decreased
- onyx_api_process_rss_bytes: Gauge of current process RSS
"""
import os
import re
from collections.abc import Awaitable
from collections.abc import Callable
import psutil
from fastapi import FastAPI
from fastapi import Request
from fastapi.routing import APIRoute
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from starlette.responses import Response
_RSS_DELTA: Histogram = Histogram(
"onyx_api_request_rss_delta_bytes",
"Absolute RSS change in bytes during a single request",
["handler"],
buckets=(
1024,
4096,
16384,
65536,
262144,
1048576,
4194304,
16777216,
),
)
_RSS_SHRINK: Counter = Counter(
"onyx_api_request_rss_shrink_total",
"Requests where RSS decreased (pages freed)",
["handler"],
)
_PROCESS_RSS: Gauge = Gauge(
"onyx_api_process_rss_bytes",
"Current process RSS in bytes",
)
_process: psutil.Process | None = None
_process_pid: int | None = None
def _get_process() -> psutil.Process:
"""Return a psutil.Process for the *current* PID.
We lazily create the Process object and cache it, but invalidate the
cache when the PID changes (e.g. after Uvicorn forks workers).
Module-level ``psutil.Process()`` would capture the *parent's* PID
and report that child's RSS from the wrong process.
"""
global _process, _process_pid
pid = os.getpid()
if _process is None or _process_pid != pid:
_process = psutil.Process(pid)
_process_pid = pid
return _process
def _build_route_map(app: FastAPI) -> list[tuple[re.Pattern[str], str]]:
route_map: list[tuple[re.Pattern[str], str]] = []
for route in app.routes:
if isinstance(route, APIRoute):
route_map.append((route.path_regex, route.path))
return route_map
def _match_route(route_map: list[tuple[re.Pattern[str], str]], path: str) -> str | None:
for pattern, template in route_map:
if pattern.match(path):
return template
return None
def add_memory_delta_middleware(app: FastAPI) -> None:
"""Register middleware that tracks per-endpoint RSS deltas.
Idempotent — safe to call multiple times (e.g. Uvicorn hot-reload).
Builds its own route map to avoid contextvar ordering issues
with the endpoint context middleware.
"""
if getattr(app.state, "_memory_delta_registered", False):
return
app.state._memory_delta_registered = True
route_map = _build_route_map(app)
@app.middleware("http")
async def memory_delta_middleware(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
handler = _match_route(route_map, request.url.path) or "unmatched"
try:
rss_before = _get_process().memory_info().rss
except (psutil.Error, OSError):
return await call_next(request)
response = await call_next(request)
try:
rss_after = _get_process().memory_info().rss
delta = rss_after - rss_before
_RSS_DELTA.labels(handler=handler).observe(abs(delta))
if delta < 0:
_RSS_SHRINK.labels(handler=handler).inc()
_PROCESS_RSS.set(rss_after)
except (psutil.Error, OSError):
pass
return response

View File

@@ -1,23 +1,35 @@
"""Prometheus metrics setup for the Onyx API server.
Orchestrates HTTP request instrumentation via ``prometheus-fastapi-instrumentator``:
- Request count, latency histograms, in-progress gauges
- Pool checkout timeout exception handler
- Custom metric callbacks (e.g. slow request counting)
Central orchestration point for ALL metrics and observability.
Functions:
- ``setup_prometheus_metrics(app)`` — HTTP request instrumentation (middleware).
Called from ``get_application()``.
- ``setup_app_observability(app)`` — app-scoped observability (middleware that
must be registered after all routers). Called from ``get_application()``.
SQLAlchemy connection pool metrics are registered separately via
``setup_postgres_connection_pool_metrics`` during application lifespan
(after engines are created).
"""
from collections.abc import Callable
from fastapi import APIRouter
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_fastapi_instrumentator.metrics import default as default_metrics
from sqlalchemy.exc import TimeoutError as SATimeoutError
from starlette.applications import Starlette
from onyx.configs.app_configs import ENABLE_ADMIN_DEBUG_ENDPOINTS
from onyx.configs.app_configs import ENABLE_DEEP_PROFILING
from onyx.server.metrics.per_tenant import per_tenant_request_callback
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
from onyx.server.metrics.slow_requests import slow_request_callback
from onyx.utils.logger import setup_logger
logger = setup_logger()
_EXCLUDED_HANDLERS = [
"/health",
@@ -73,3 +85,76 @@ def setup_prometheus_metrics(app: Starlette) -> None:
instrumentator.add(per_tenant_request_callback)
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)
def setup_app_observability(
app: FastAPI,
include_router_fn: Callable[[FastAPI, APIRouter], None] | None = None,
) -> None:
"""Register app-scoped observability components.
Must be called in ``get_application()`` AFTER all routers are registered
(memory delta middleware builds its route map at registration time).
Args:
app: The FastAPI application.
include_router_fn: Callback ``(app, router) -> None`` that includes
a router with the application's global prefix. If ``None``,
falls back to ``app.include_router(router)``.
"""
if ENABLE_ADMIN_DEBUG_ENDPOINTS:
from onyx.server.metrics.admin_debug import router as debug_router
if include_router_fn is not None:
include_router_fn(app, debug_router)
else:
app.include_router(debug_router)
from onyx.server.metrics.memory_delta import add_memory_delta_middleware
add_memory_delta_middleware(app)
def start_observability() -> None:
"""Start lifespan-scoped observability probes and collectors.
Called from ``lifespan()`` after engines/pools are ready.
"""
from onyx.server.metrics.event_loop_lag import start_event_loop_lag_probe
from onyx.server.metrics.redis_connection_pool import (
setup_redis_connection_pool_metrics,
)
from onyx.server.metrics.threadpool import setup_threadpool_metrics
setup_redis_connection_pool_metrics()
setup_threadpool_metrics()
start_event_loop_lag_probe()
if ENABLE_ADMIN_DEBUG_ENDPOINTS:
from onyx.server.metrics.admin_debug import set_start_time
set_start_time()
if ENABLE_DEEP_PROFILING:
from onyx.server.metrics.deep_profiling import start_deep_profiling
start_deep_profiling()
logger.info("Observability metrics started")
async def stop_observability() -> None:
"""Shut down lifespan-scoped observability probes.
Called from ``lifespan()`` after yield, before engine teardown.
"""
from onyx.server.metrics.event_loop_lag import stop_event_loop_lag_probe
await stop_event_loop_lag_probe()
if ENABLE_DEEP_PROFILING:
from onyx.server.metrics.deep_profiling import stop_deep_profiling
await stop_deep_profiling()
logger.info("Observability metrics stopped")


@@ -0,0 +1,123 @@
"""Redis connection pool Prometheus collector.
Reads pool internals from redis.BlockingConnectionPool on each
Prometheus scrape to report utilization metrics.
Metrics:
- onyx_redis_pool_in_use: Currently checked-out connections
- onyx_redis_pool_available: Idle connections in the pool
- onyx_redis_pool_max: Configured max_connections
- onyx_redis_pool_created: Lifetime connections created
"""
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from prometheus_client.registry import REGISTRY
from redis import BlockingConnectionPool
from onyx.utils.logger import setup_logger
logger = setup_logger()
class RedisPoolCollector(Collector):
"""Custom collector that reads BlockingConnectionPool internals on scrape.
NOTE: Uses private redis-py attributes (_in_use_connections,
_available_connections, _created_connections) because there is no
public API for pool statistics. Wrapped in try/except so a redis-py
upgrade changing internals degrades gracefully (metrics go to 0)
instead of crashing every scrape.
"""
def __init__(self) -> None:
self._pools: list[tuple[str, BlockingConnectionPool]] = []
def add_pool(self, label: str, pool: BlockingConnectionPool) -> None:
self._pools.append((label, pool))
def collect(self) -> list[GaugeMetricFamily]:
in_use = GaugeMetricFamily(
"onyx_redis_pool_in_use",
"Currently checked-out Redis connections",
labels=["pool"],
)
available = GaugeMetricFamily(
"onyx_redis_pool_available",
"Idle Redis connections in the pool",
labels=["pool"],
)
max_conns = GaugeMetricFamily(
"onyx_redis_pool_max",
"Configured max Redis connections",
labels=["pool"],
)
created = GaugeMetricFamily(
"onyx_redis_pool_created",
"Lifetime Redis connections created",
labels=["pool"],
)
for label, pool in self._pools:
try:
in_use.add_metric([label], len(pool._in_use_connections))
available.add_metric([label], len(pool._available_connections))
max_conns.add_metric([label], pool.max_connections)
created.add_metric([label], pool._created_connections)
except (AttributeError, TypeError):
# Degrade to zeros so the time series stays visible
# instead of disappearing when internals change.
in_use.add_metric([label], 0)
available.add_metric([label], 0)
max_conns.add_metric([label], 0)
created.add_metric([label], 0)
logger.warning(
"Redis pool %s: falling back to zero metrics — "
"redis-py internals may have changed",
label,
exc_info=True,
)
return [in_use, available, max_conns, created]
def describe(self) -> list[GaugeMetricFamily]:
# Return empty to mark this as an "unchecked" collector.
# Prometheus checks describe() vs collect() for consistency;
# returning empty opts out since our metrics are dynamic.
return []
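The degrade-to-zero contract can be distilled into a small helper. A sketch with `FakePool` as a hypothetical stand-in whose attribute names mirror the redis-py internals read above:

```python
def read_pool_stats(pool) -> dict[str, int]:
    try:
        return {
            "in_use": len(pool._in_use_connections),
            "available": len(pool._available_connections),
            "max": pool.max_connections,
            "created": pool._created_connections,
        }
    except (AttributeError, TypeError):
        # Internals changed in a redis-py upgrade: keep the series
        # visible at zero instead of failing the whole scrape.
        return {"in_use": 0, "available": 0, "max": 0, "created": 0}


class FakePool:
    """Hypothetical stand-in mimicking BlockingConnectionPool internals."""

    _in_use_connections = {"a", "b"}
    _available_connections = ["c"]
    max_connections = 10
    _created_connections = 3
```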
_redis_collector: RedisPoolCollector | None = None
def setup_redis_connection_pool_metrics() -> None:
"""Register Redis pool metrics using the RedisPool singleton.
Idempotent — safe to call multiple times (e.g. Uvicorn hot-reload).
On hot-reload, the module re-imports and ``_redis_collector`` resets
to ``None``, but the REGISTRY still holds the old collector.
We catch the ``ValueError`` from duplicate registration and update
(Note: on the duplicate-registration path the module-level guard is still set, so subsequent calls remain no-ops; the already-registered collector keeps serving scrapes.)
the module-level reference to the existing collector.
"""
global _redis_collector
if _redis_collector is not None:
return
from onyx.redis.redis_pool import RedisPool
pool_instance = RedisPool()
collector = RedisPoolCollector()
collector.add_pool("primary", pool_instance._pool)
# Replica pool always exists (defaults to same host as primary when
# REDIS_REPLICA_HOST is not set). Still worth monitoring separately
# since it maintains an independent connection pool.
collector.add_pool("replica", pool_instance._replica_pool)
try:
REGISTRY.register(collector)
except ValueError:
# Already registered from a previous module load (Uvicorn reload).
# The old collector still works — just update our reference.
logger.debug("Redis pool collector already registered, skipping")
_redis_collector = collector
logger.info("Registered Redis connection pool metrics")


@@ -0,0 +1,143 @@
"""Thread pool instrumentation.
Provides an InstrumentedThreadPoolExecutor that wraps submit() to
track task submission, active count, and duration. Also exports a
custom Collector for process-wide thread count.
Metrics:
- onyx_threadpool_tasks_submitted_total: Counter of submitted tasks
- onyx_threadpool_tasks_active: Gauge of currently executing tasks
- onyx_threadpool_task_duration_seconds: Histogram of task execution time
- onyx_process_thread_count: Gauge of total process threads (via psutil)
"""
import os
import time
from collections.abc import Callable
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import psutil
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from prometheus_client.registry import REGISTRY
from onyx.utils.logger import setup_logger
logger = setup_logger()
_TASKS_SUBMITTED: Counter = Counter(
"onyx_threadpool_tasks_submitted_total",
"Total tasks submitted to thread pools",
)
_TASKS_ACTIVE: Gauge = Gauge(
"onyx_threadpool_tasks_active",
"Currently executing thread pool tasks",
)
_TASK_DURATION: Histogram = Histogram(
"onyx_threadpool_task_duration_seconds",
"Thread pool task execution duration in seconds",
)
_process: psutil.Process | None = None
_process_pid: int | None = None
def _get_process() -> psutil.Process:
"""Return a psutil.Process for the *current* PID.
Lazily created and invalidated when PID changes (fork).
Not locked — worst case on a benign race is creating two Process
objects for the same PID; one gets discarded. The default
CollectorRegistry serializes collect() calls anyway.
"""
global _process, _process_pid
pid = os.getpid()
if _process is None or _process_pid != pid:
_process = psutil.Process(pid)
_process_pid = pid
return _process
class InstrumentedThreadPoolExecutor(ThreadPoolExecutor):
"""ThreadPoolExecutor subclass that records Prometheus metrics."""
def submit(
self,
fn: Callable[..., Any],
/,
*args: Any,
**kwargs: Any,
) -> Future[Any]:
def _wrapped() -> Any:
# _wrapped runs inside the thread pool worker, so both the
# active gauge and the duration timer reflect *execution* time
# only — queue wait time is excluded.
_TASKS_ACTIVE.inc()
start = time.monotonic()
try:
return fn(*args, **kwargs)
finally:
_TASKS_ACTIVE.dec()
_TASK_DURATION.observe(time.monotonic() - start)
# Increment *after* super().submit() so we don't count tasks
# that fail to submit (e.g. pool already shut down).
future = super().submit(_wrapped)
_TASKS_SUBMITTED.inc()
return future
class ThreadCountCollector(Collector):
"""Reports the process-wide thread count on each Prometheus scrape."""
def collect(self) -> list[GaugeMetricFamily]:
family = GaugeMetricFamily(
"onyx_process_thread_count",
"Total OS threads in the process",
)
try:
family.add_metric([], _get_process().num_threads())
except (psutil.Error, OSError):
logger.warning("Failed to read process thread count", exc_info=True)
family.add_metric([], 0)
return [family]
def describe(self) -> list[GaugeMetricFamily]:
# Return empty to mark this as an "unchecked" collector.
# Prometheus checks describe() vs collect() for consistency;
# returning empty opts out since our metrics are dynamic.
return []
_thread_collector: ThreadCountCollector | None = None
def setup_threadpool_metrics() -> None:
"""Register the process thread count collector and enable instrumentation.
Idempotent — safe to call multiple times (e.g. Uvicorn hot-reload).
Uses try/except on REGISTRY.register() to handle the case where the
module is reimported (guard resets) but REGISTRY still holds the old
collector.
"""
global _thread_collector
if _thread_collector is not None:
return
from onyx.utils.threadpool_concurrency import enable_threadpool_instrumentation
enable_threadpool_instrumentation()
collector = ThreadCountCollector()
try:
REGISTRY.register(collector)
except ValueError:
logger.debug("Thread count collector already registered, skipping")
_thread_collector = collector


@@ -30,6 +30,32 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_threadpool_instrumentation_enabled: bool = False
def _get_executor_class() -> type[ThreadPoolExecutor]:
"""Return InstrumentedThreadPoolExecutor when running in the API server.
Non-server contexts (Celery workers, CLI scripts) get vanilla
ThreadPoolExecutor because enable_threadpool_instrumentation() is
never called there. The flag lives here (not in the metrics module)
to avoid importing prometheus_client as a side effect of every
parallel operation.
"""
if _threadpool_instrumentation_enabled:
from onyx.server.metrics.threadpool import InstrumentedThreadPoolExecutor
return InstrumentedThreadPoolExecutor
return ThreadPoolExecutor
def enable_threadpool_instrumentation() -> None:
"""Called by setup_threadpool_metrics() during API server startup."""
global _threadpool_instrumentation_enabled
_threadpool_instrumentation_enabled = True
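The flag-plus-lazy-import pattern keeps `prometheus_client` out of worker processes entirely. A self-contained sketch of the selection logic (class and function names here are illustrative, and the lazy import is elided):

```python
from concurrent.futures import ThreadPoolExecutor

_instrumentation_enabled = False


class InstrumentedExecutor(ThreadPoolExecutor):
    """Stand-in for the metrics-aware subclass; real code imports it lazily."""


def get_executor_class() -> type[ThreadPoolExecutor]:
    # Workers and CLIs never flip the flag, so they never reach the
    # branch that would import the metrics module.
    if _instrumentation_enabled:
        return InstrumentedExecutor
    return ThreadPoolExecutor


def enable_instrumentation() -> None:
    global _instrumentation_enabled
    _instrumentation_enabled = True
```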
R = TypeVar("R")
KT = TypeVar("KT") # Key type
VT = TypeVar("VT") # Value type
@@ -323,7 +349,8 @@ def run_functions_tuples_in_parallel(
return []
results: list[tuple[int, Any]] = []
executor = ThreadPoolExecutor(max_workers=workers)
executor_cls = _get_executor_class()
executor = executor_cls(max_workers=workers)
try:
# The primary reason for propagating contextvars is to allow acquiring a db session
@@ -421,7 +448,8 @@ def run_functions_in_parallel(
if len(function_calls) == 0:
return results
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
executor_cls = _get_executor_class()
with executor_cls(max_workers=len(function_calls)) as executor:
future_to_id = {
executor.submit(
contextvars.copy_context().run, func_call.execute
@@ -543,7 +571,8 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
if you are consuming all elements from the generators OR it is acceptable
for some extra generator code to run and not have the result(s) yielded.
"""
with ThreadPoolExecutor(max_workers=max_workers) as executor:
executor_cls = _get_executor_class()
with executor_cls(max_workers=max_workers) as executor:
future_to_index: dict[Future[tuple[int, R | None]], int] = {
executor.submit(_next_or_none, ind, gen): ind
for ind, gen in enumerate(gens)


@@ -0,0 +1,136 @@
"""Unit tests for deep profiling metrics collector."""
from unittest.mock import MagicMock
from onyx.server.metrics.deep_profiling import _strip_path
from onyx.server.metrics.deep_profiling import DeepProfilingCollector
def test_strip_path_site_packages() -> None:
"""Verify site-packages prefix is stripped."""
path = "/usr/lib/python3.11/site-packages/onyx/chat/process.py"
assert _strip_path(path) == "onyx/chat/process.py"
def test_strip_path_dist_packages() -> None:
path = "/usr/lib/python3/dist-packages/sqlalchemy/engine.py"
assert _strip_path(path) == "sqlalchemy/engine.py"
def test_strip_path_cwd() -> None:
"""Verify cwd prefix is stripped."""
import os
cwd = os.getcwd()
path = f"{cwd}/onyx/server/main.py"
assert _strip_path(path) == "onyx/server/main.py"
def test_strip_path_unknown_returns_as_is() -> None:
path = "/some/random/path.py"
assert _strip_path(path) == path
def _make_mock_stat(filename: str, lineno: int, size: int, count: int) -> MagicMock:
stat = MagicMock()
frame = MagicMock()
frame.filename = filename
frame.lineno = lineno
stat.traceback = [frame]
stat.size = size
stat.count = count
stat.size_diff = size # For delta stats
return stat
def test_collector_exports_tracemalloc_metrics() -> None:
"""Verify the collector exports top allocation sites."""
import onyx.server.metrics.deep_profiling as mod
original_top = mod._current_top_stats
original_delta = mod._current_delta_stats
original_total = mod._current_total_bytes
try:
mod._current_top_stats = [
_make_mock_stat("site-packages/onyx/chat.py", 42, 1024, 10),
_make_mock_stat("site-packages/onyx/db.py", 100, 2048, 5),
]
mod._current_delta_stats = [
_make_mock_stat("site-packages/onyx/chat.py", 42, 512, 3),
]
mod._current_total_bytes = 3072
collector = DeepProfilingCollector()
families = collector.collect()
# Find specific metric families by name
family_names = [f.name for f in families]
assert "onyx_tracemalloc_top_bytes" in family_names
assert "onyx_tracemalloc_top_count" in family_names
assert "onyx_tracemalloc_delta_bytes" in family_names
assert "onyx_tracemalloc_total_bytes" in family_names
assert "onyx_gc_collections" in family_names
assert "onyx_gc_collected" in family_names
assert "onyx_gc_uncollectable" in family_names
assert "onyx_object_type_count" in family_names
# Verify top_bytes values
top_bytes_family = next(
f for f in families if f.name == "onyx_tracemalloc_top_bytes"
)
values = {s.labels["source"]: s.value for s in top_bytes_family.samples}
assert values["onyx/chat.py:42"] == 1024
assert values["onyx/db.py:100"] == 2048
# Verify total
total_family = next(
f for f in families if f.name == "onyx_tracemalloc_total_bytes"
)
assert total_family.samples[0].value == 3072
finally:
mod._current_top_stats = original_top
mod._current_delta_stats = original_delta
mod._current_total_bytes = original_total
def test_collector_exports_gc_stats() -> None:
"""Verify GC generation stats are exported."""
collector = DeepProfilingCollector()
families = collector.collect()
gc_collections = next(f for f in families if f.name == "onyx_gc_collections")
# Should have 3 generations (0, 1, 2)
assert len(gc_collections.samples) == 3
generations = {s.labels["generation"] for s in gc_collections.samples}
assert generations == {"0", "1", "2"}
def test_collector_exports_object_type_counts() -> None:
"""Verify object type counts are exported from cached snapshot data."""
import onyx.server.metrics.deep_profiling as mod
original = mod._current_object_type_counts
try:
mod._current_object_type_counts = [
("dict", 5000),
("list", 3000),
("tuple", 2000),
]
collector = DeepProfilingCollector()
families = collector.collect()
type_count = next(f for f in families if f.name == "onyx_object_type_count")
assert len(type_count.samples) == 3
values = {s.labels["type"]: s.value for s in type_count.samples}
assert values["dict"] == 5000
assert values["list"] == 3000
finally:
mod._current_object_type_counts = original
def test_collector_describe_returns_empty() -> None:
collector = DeepProfilingCollector()
assert collector.describe() == []


@@ -0,0 +1,69 @@
"""Unit tests for event loop lag probe."""
import asyncio
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.server.metrics.event_loop_lag import _probe_loop
from onyx.server.metrics.event_loop_lag import start_event_loop_lag_probe
from onyx.server.metrics.event_loop_lag import stop_event_loop_lag_probe
@pytest.mark.asyncio
@patch("onyx.server.metrics.event_loop_lag._LAG")
@patch("onyx.server.metrics.event_loop_lag._LAG_MAX")
async def test_probe_measures_lag(
mock_lag_max: MagicMock, # noqa: ARG001
mock_lag: MagicMock,
) -> None:
"""Verify the probe records non-negative lag after sleeping."""
import onyx.server.metrics.event_loop_lag as mod
original_lag = mod._current_lag
original_max = mod._max_lag
mod._current_lag = 0.0
mod._max_lag = 0.0
try:
# Run the probe with a very short interval so it fires quickly
task = asyncio.create_task(_probe_loop(0.01))
await asyncio.sleep(0.05) # Let it fire a few times
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
# The lag gauge should have been set at least once
assert mock_lag.set.call_count >= 1
# All observed lag values should be non-negative
for call in mock_lag.set.call_args_list:
assert call[0][0] >= 0.0
finally:
mod._current_lag = original_lag
mod._max_lag = original_max
@pytest.mark.asyncio
async def test_start_stop_lifecycle() -> None:
"""Verify start/stop create and cancel the task."""
import onyx.server.metrics.event_loop_lag as mod
original_task = mod._probe_task
mod._probe_task = None
try:
with patch(
"onyx.server.metrics.event_loop_lag.EVENT_LOOP_LAG_PROBE_INTERVAL_SECONDS",
0.01,
):
start_event_loop_lag_probe()
assert mod._probe_task is not None
assert not mod._probe_task.cancelled()
await stop_event_loop_lag_probe()
assert mod._probe_task is None
finally:
mod._probe_task = original_task


@@ -0,0 +1,152 @@
"""Unit tests for per-endpoint memory delta middleware."""
from unittest.mock import MagicMock
from unittest.mock import patch
import psutil
from fastapi import FastAPI
from starlette.testclient import TestClient
from onyx.server.metrics.memory_delta import _build_route_map
from onyx.server.metrics.memory_delta import _match_route
from onyx.server.metrics.memory_delta import add_memory_delta_middleware
def _make_app() -> FastAPI:
app = FastAPI()
@app.get("/api/chat/{chat_id}")
def get_chat(chat_id: str) -> dict[str, str]:
return {"id": chat_id}
@app.get("/api/health")
def health() -> dict[str, str]:
return {"status": "ok"}
return app
def test_build_route_map_extracts_api_routes() -> None:
app = _make_app()
route_map = _build_route_map(app)
templates = [template for _, template in route_map]
assert "/api/chat/{chat_id}" in templates
assert "/api/health" in templates
def test_match_route_returns_template() -> None:
app = _make_app()
route_map = _build_route_map(app)
assert _match_route(route_map, "/api/chat/abc-123") == "/api/chat/{chat_id}"
assert _match_route(route_map, "/api/health") == "/api/health"
assert _match_route(route_map, "/nonexistent") is None
@patch("onyx.server.metrics.memory_delta._get_process")
@patch("onyx.server.metrics.memory_delta._RSS_SHRINK")
@patch("onyx.server.metrics.memory_delta._RSS_DELTA")
@patch("onyx.server.metrics.memory_delta._PROCESS_RSS")
def test_middleware_observes_rss_delta(
mock_rss_gauge: MagicMock,
mock_histogram: MagicMock,
mock_shrink: MagicMock,
mock_get_process: MagicMock,
) -> None:
"""Verify the middleware measures RSS before/after and records abs(delta)."""
mem_before = MagicMock()
mem_before.rss = 100_000_000
mem_after = MagicMock()
mem_after.rss = 100_065_536
mock_proc = MagicMock()
mock_proc.memory_info.side_effect = [mem_before, mem_after]
mock_get_process.return_value = mock_proc
app = _make_app()
add_memory_delta_middleware(app)
client = TestClient(app)
response = client.get("/api/health")
assert response.status_code == 200
mock_histogram.labels.assert_called_with(handler="/api/health")
mock_histogram.labels().observe.assert_called_once_with(65_536)
mock_shrink.labels().inc.assert_not_called()
mock_rss_gauge.set.assert_called_once_with(100_065_536)
@patch("onyx.server.metrics.memory_delta._get_process")
@patch("onyx.server.metrics.memory_delta._RSS_SHRINK")
@patch("onyx.server.metrics.memory_delta._RSS_DELTA")
@patch("onyx.server.metrics.memory_delta._PROCESS_RSS")
def test_middleware_tracks_rss_shrink(
mock_rss_gauge: MagicMock, # noqa: ARG001
mock_histogram: MagicMock,
mock_shrink: MagicMock,
mock_get_process: MagicMock,
) -> None:
"""When RSS decreases, observe abs(delta) and increment shrink counter."""
mem_before = MagicMock()
mem_before.rss = 100_065_536
mem_after = MagicMock()
mem_after.rss = 100_000_000
mock_proc = MagicMock()
mock_proc.memory_info.side_effect = [mem_before, mem_after]
mock_get_process.return_value = mock_proc
app = _make_app()
add_memory_delta_middleware(app)
client = TestClient(app)
client.get("/api/health")
mock_histogram.labels().observe.assert_called_once_with(65_536)
mock_shrink.labels.assert_called_with(handler="/api/health")
mock_shrink.labels().inc.assert_called_once()
@patch("onyx.server.metrics.memory_delta._get_process")
@patch("onyx.server.metrics.memory_delta._RSS_DELTA")
@patch("onyx.server.metrics.memory_delta._PROCESS_RSS")
def test_middleware_uses_unmatched_for_unknown_paths(
mock_rss_gauge: MagicMock, # noqa: ARG001
mock_histogram: MagicMock,
mock_get_process: MagicMock,
) -> None:
mem_info = MagicMock()
mem_info.rss = 50_000_000
mock_proc = MagicMock()
mock_proc.memory_info.return_value = mem_info
mock_get_process.return_value = mock_proc
app = _make_app()
add_memory_delta_middleware(app)
client = TestClient(app, raise_server_exceptions=False)
client.get("/totally-unknown")
mock_histogram.labels.assert_called_with(handler="unmatched")
@patch("onyx.server.metrics.memory_delta._get_process")
@patch("onyx.server.metrics.memory_delta._RSS_DELTA")
@patch("onyx.server.metrics.memory_delta._PROCESS_RSS")
def test_middleware_skips_metrics_on_psutil_error(
mock_rss_gauge: MagicMock, # noqa: ARG001
mock_histogram: MagicMock,
mock_get_process: MagicMock,
) -> None:
"""When psutil raises on the initial memory_info call, middleware skips metrics."""
mock_proc = MagicMock()
mock_proc.memory_info.side_effect = psutil.Error("no such process")
mock_get_process.return_value = mock_proc
app = _make_app()
add_memory_delta_middleware(app)
client = TestClient(app)
response = client.get("/api/health")
assert response.status_code == 200
mock_histogram.labels.assert_not_called()


@@ -0,0 +1,86 @@
"""Unit tests for Redis connection pool metrics collector."""
from unittest.mock import MagicMock
from onyx.server.metrics.redis_connection_pool import RedisPoolCollector
def test_redis_pool_collector_reports_metrics() -> None:
"""Verify the collector reads pool internals correctly."""
mock_pool = MagicMock()
mock_pool._in_use_connections = {"conn1", "conn2", "conn3"}
mock_pool._available_connections = ["conn4", "conn5"]
mock_pool.max_connections = 128
mock_pool._created_connections = 5
collector = RedisPoolCollector()
collector.add_pool("primary", mock_pool)
families = collector.collect()
assert len(families) == 4
metrics: dict[str, float] = {}
for family in families:
for sample in family.samples:
metrics[f"{sample.name}:{sample.labels['pool']}"] = sample.value
assert metrics["onyx_redis_pool_in_use:primary"] == 3
assert metrics["onyx_redis_pool_available:primary"] == 2
assert metrics["onyx_redis_pool_max:primary"] == 128
assert metrics["onyx_redis_pool_created:primary"] == 5
def test_redis_pool_collector_handles_multiple_pools() -> None:
"""Verify the collector supports primary + replica pools."""
primary = MagicMock()
primary._in_use_connections = {"a"}
primary._available_connections = ["b", "c"]
primary.max_connections = 128
primary._created_connections = 3
replica = MagicMock()
replica._in_use_connections = set()
replica._available_connections = ["d"]
replica.max_connections = 64
replica._created_connections = 1
collector = RedisPoolCollector()
collector.add_pool("primary", primary)
collector.add_pool("replica", replica)
families = collector.collect()
metrics: dict[str, float] = {}
for family in families:
for sample in family.samples:
metrics[f"{sample.name}:{sample.labels['pool']}"] = sample.value
assert metrics["onyx_redis_pool_in_use:primary"] == 1
assert metrics["onyx_redis_pool_in_use:replica"] == 0
assert metrics["onyx_redis_pool_max:replica"] == 64
def test_redis_pool_collector_falls_back_to_zeros_on_attribute_error() -> None:
"""Verify collector degrades gracefully when redis-py internals change."""
mock_pool = MagicMock(spec=[]) # empty spec — no attributes at all
collector = RedisPoolCollector()
collector.add_pool("primary", mock_pool)
families = collector.collect()
assert len(families) == 4
metrics: dict[str, float] = {}
for family in families:
for sample in family.samples:
metrics[f"{sample.name}:{sample.labels['pool']}"] = sample.value
# All metrics should fall back to zero
assert metrics["onyx_redis_pool_in_use:primary"] == 0
assert metrics["onyx_redis_pool_available:primary"] == 0
assert metrics["onyx_redis_pool_max:primary"] == 0
assert metrics["onyx_redis_pool_created:primary"] == 0
def test_redis_pool_collector_describe_returns_empty() -> None:
"""Unchecked collector pattern — describe() returns empty."""
collector = RedisPoolCollector()
assert collector.describe() == []


@@ -0,0 +1,70 @@
"""Unit tests for thread pool instrumentation."""
from unittest.mock import patch
import pytest
from onyx.server.metrics.threadpool import InstrumentedThreadPoolExecutor
from onyx.server.metrics.threadpool import ThreadCountCollector
def test_instrumented_executor_tracks_submissions() -> None:
"""Verify counter increments and gauge tracks active tasks."""
with (
patch("onyx.server.metrics.threadpool._TASKS_SUBMITTED") as mock_submitted,
patch("onyx.server.metrics.threadpool._TASKS_ACTIVE") as mock_active,
patch("onyx.server.metrics.threadpool._TASK_DURATION") as mock_duration,
):
with InstrumentedThreadPoolExecutor(max_workers=2) as executor:
future = executor.submit(lambda: 42)
result = future.result(timeout=5)
assert result == 42
mock_submitted.inc.assert_called_once()
mock_active.inc.assert_called_once()
mock_active.dec.assert_called_once()
mock_duration.observe.assert_called_once()
# Duration should be non-negative
observed_duration = mock_duration.observe.call_args[0][0]
assert observed_duration >= 0
def test_instrumented_executor_handles_exceptions() -> None:
"""Verify metrics still fire when the task raises."""
with (
patch("onyx.server.metrics.threadpool._TASKS_SUBMITTED") as mock_submitted,
patch("onyx.server.metrics.threadpool._TASKS_ACTIVE") as mock_active,
patch("onyx.server.metrics.threadpool._TASK_DURATION") as mock_duration,
):
with InstrumentedThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(lambda: 1 / 0)
with pytest.raises(ZeroDivisionError):
future.result(timeout=5)
# Metrics should still be recorded even on failure
mock_submitted.inc.assert_called_once()
mock_active.inc.assert_called_once()
mock_active.dec.assert_called_once()
mock_duration.observe.assert_called_once()
def test_thread_count_collector_reports_threads() -> None:
"""Verify the collector returns the process thread count."""
with patch("onyx.server.metrics.threadpool._get_process") as mock_get_process:
mock_get_process.return_value.num_threads.return_value = 15
collector = ThreadCountCollector()
families = collector.collect()
assert len(families) == 1
samples = families[0].samples
assert len(samples) == 1
assert samples[0].value == 15
def test_thread_count_collector_describe_returns_empty() -> None:
collector = ThreadCountCollector()
assert collector.describe() == []


@@ -57,9 +57,16 @@ from onyx.server.metrics.my_metric import my_metric_callback
instrumentator.add(my_metric_callback)
```
### 4. Wire it into setup_prometheus_metrics (if infrastructure-scoped)
### 4. Wire it into the orchestration layer (if infrastructure-scoped)
For metrics that attach to engines, pools, or background systems, add a setup function and call it from `setup_prometheus_metrics()` in `metrics/prometheus_setup.py`:
For metrics that attach to engines, pools, or background systems, add a setup function and call it from the appropriate orchestration function in `metrics/prometheus_setup.py`:
| Function | Called from | Purpose |
|----------|-------------|---------|
| `setup_prometheus_metrics(app)` | `get_application()` | HTTP request instrumentation (middleware) |
| `setup_app_observability(app)` | `get_application()` | App-scoped components (middleware registered after routers) |
For lifespan-scoped metrics (probes, collectors that need engines/pools ready), add a setup function and call it from `start_observability()` in `metrics/prometheus_setup.py`:
```python
# metrics/my_metric.py
@@ -69,15 +76,13 @@ def setup_my_metrics(resource: SomeResource) -> None:
```
```python
# metrics/prometheus_setup.py — inside setup_prometheus_metrics()
# metrics/prometheus_setup.py — inside start_observability()
from onyx.server.metrics.my_metric import setup_my_metrics
def setup_prometheus_metrics(app, engines=None) -> None:
setup_my_metrics(resource) # Add your call here
...
setup_my_metrics(resource)
```
All metrics initialization is funneled through the single `setup_prometheus_metrics()` call in `onyx/main.py:lifespan()`. Do not add separate setup calls to `main.py`.
All metrics initialization is funneled through `metrics/prometheus_setup.py`. Do not add separate setup calls to `main.py`.
### 5. Write tests
@@ -169,6 +174,125 @@ Engine label values: `sync` (main read-write), `async` (async sessions), `readon
Connections from background tasks (Celery) or boot-time warmup appear as `handler="unknown"`.
## Memory Metrics
Always-on, sub-microsecond overhead per request (single `psutil` syscall).
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_api_request_rss_delta_bytes` | Histogram | `handler` | Absolute RSS change in bytes during a request |
| `onyx_api_request_rss_shrink_total` | Counter | `handler` | Requests where RSS decreased (pages freed) |
| `onyx_api_process_rss_bytes` | Gauge | — | Current process RSS |
The histogram tracks `abs(delta)` so `histogram_quantile()` works correctly.
Use the shrink counter to distinguish growth from reclamation.
```promql
# Top 5 endpoints by average memory impact per request
topk(5, avg by (handler)(
rate(onyx_api_request_rss_delta_bytes_sum[5m])
/ rate(onyx_api_request_rss_delta_bytes_count[5m])
))
# Endpoints with frequent RSS shrinkage (GC/mmap release)
topk(5, rate(onyx_api_request_rss_shrink_total[5m]))
```
## Redis Pool Metrics
Read from `BlockingConnectionPool` internals on each `/metrics` scrape.
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_redis_pool_in_use` | Gauge | `pool` | Checked-out connections |
| `onyx_redis_pool_available` | Gauge | `pool` | Idle connections |
| `onyx_redis_pool_max` | Gauge | `pool` | Configured max |
| `onyx_redis_pool_created` | Gauge | `pool` | Lifetime connections created |
Pool label values: `primary`, `replica`.
```promql
# Redis pool utilization (alert if > 80%)
onyx_redis_pool_in_use{pool="primary"} / onyx_redis_pool_max{pool="primary"}
```
## Event Loop Metrics
Always-on background asyncio task. Detects blocked event loops.
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_api_event_loop_lag_seconds` | Gauge | — | Current scheduling lag |
| `onyx_api_event_loop_lag_max_seconds` | Gauge | — | Max lag since process start |
Configurable via `EVENT_LOOP_LAG_PROBE_INTERVAL_SECONDS` (default `2.0`).
```promql
# Alert if event loop is blocked > 100ms
onyx_api_event_loop_lag_seconds > 0.1
```
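The lag numbers come from measuring how late `asyncio.sleep()` wakes up relative to the requested interval. A minimal sketch of that mechanism (the metric wiring in `metrics/event_loop_lag.py` is omitted):

```python
import asyncio
import time

lags: list[float] = []


async def probe(interval: float, iterations: int) -> None:
    for _ in range(iterations):
        start = time.monotonic()
        await asyncio.sleep(interval)
        # Anything beyond `interval` is time the loop spent busy elsewhere.
        lags.append(max(0.0, time.monotonic() - start - interval))


asyncio.run(probe(0.01, 3))
```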
## Thread Pool Metrics
Collected via `InstrumentedThreadPoolExecutor` (wraps `ThreadPoolExecutor` usage in `threadpool_concurrency.py`).
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_threadpool_tasks_submitted_total` | Counter | — | Total tasks submitted |
| `onyx_threadpool_tasks_active` | Gauge | — | Currently executing tasks |
| `onyx_threadpool_task_duration_seconds` | Histogram | — | Task execution duration |
| `onyx_process_thread_count` | Gauge | — | OS threads in the process |
```promql
# Rising thread count = potential leak
onyx_process_thread_count
```
## Deep Profiling Metrics (opt-in)
Requires `ENABLE_DEEP_PROFILING=true`. Adds ~10-20% allocation overhead.
### Configuration
| Env Var | Default | Description |
|---------|---------|-------------|
| `ENABLE_DEEP_PROFILING` | `false` | Enable tracemalloc + GC + object counting |
| `DEEP_PROFILING_SNAPSHOT_INTERVAL_SECONDS` | `60.0` | Interval between snapshots |
| `DEEP_PROFILING_TOP_N_ALLOCATIONS` | `20` | Top allocation sites to export |
| `DEEP_PROFILING_TOP_N_TYPES` | `30` | Top object types to export |
### Metrics
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_tracemalloc_top_bytes` | Gauge | `source` | Bytes by top allocation sites |
| `onyx_tracemalloc_top_count` | Gauge | `source` | Allocation count by source |
| `onyx_tracemalloc_delta_bytes` | Gauge | `source` | Growth since previous snapshot |
| `onyx_tracemalloc_total_bytes` | Gauge | — | Total traced memory |
| `onyx_gc_collections_total` | Counter | `generation` | GC runs per generation |
| `onyx_gc_collected_total` | Counter | `generation` | Objects collected |
| `onyx_gc_uncollectable_total` | Counter | `generation` | Uncollectable objects |
| `onyx_object_type_count` | Gauge | `type` | Live objects by type (top N) |
```promql
# Top leaking code locations
topk(10, onyx_tracemalloc_delta_bytes)
# GC uncollectable (true leaks)
rate(onyx_gc_uncollectable_total[5m])
```
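The underlying allocation data comes from the stdlib `tracemalloc` module. A minimal sketch of what one snapshot cycle gathers (interval scheduling and delta computation omitted):

```python
import tracemalloc

tracemalloc.start()
_garbage = [bytearray(1024) for _ in range(100)]  # allocate something traceable

snapshot = tracemalloc.take_snapshot()
stats = snapshot.statistics("lineno")           # grouped file:line, like the `source` label
top_n = stats[:20]                              # cf. DEEP_PROFILING_TOP_N_ALLOCATIONS
total_bytes = sum(stat.size for stat in stats)  # cf. onyx_tracemalloc_total_bytes

tracemalloc.stop()
```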
## Admin Debug Endpoints (opt-in)
Requires `ENABLE_ADMIN_DEBUG_ENDPOINTS=true`. All require admin auth.
| Endpoint | Method | Returns |
|----------|--------|---------|
| `/admin/debug/process-info` | GET | RSS, VMS, CPU%, FD count, threads, uptime |
| `/admin/debug/pool-state` | GET | Postgres + Redis pool state as JSON |
| `/admin/debug/threads` | GET | All threads (name, daemon, ident) |
| `/admin/debug/event-loop-lag` | GET | Current + max event loop lag |
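A request sketch for these endpoints, assuming a local deployment and bearer-token admin auth (the host, port, and auth header shape are deployment-specific assumptions, not part of this feature):

```python
from urllib.request import Request

BASE_URL = "http://localhost:8080"  # hypothetical host/port


def debug_request(endpoint: str, token: str) -> Request:
    # Auth header shape depends on your deployment's admin auth scheme.
    return Request(
        f"{BASE_URL}/admin/debug/{endpoint}",
        headers={"Authorization": f"Bearer {token}"},
    )


req = debug_request("process-info", "my-admin-token")
```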
## Example PromQL Queries
### Which endpoints are saturated right now?