Compare commits

...

1 Commits

Author SHA1 Message Date
Dane Urban
5115b621c8 n 2026-02-10 21:01:32 -08:00
2 changed files with 376 additions and 32 deletions

View File

@@ -1,4 +1,8 @@
import os
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from contextlib import nullcontext
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -49,6 +53,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_env_lock = threading.Lock()
if TYPE_CHECKING:
from litellm import CustomStreamWrapper
from litellm import HTTPHandler
@@ -378,23 +384,29 @@ class LitellmLLM(LLM):
if "api_key" not in passthrough_kwargs:
passthrough_kwargs["api_key"] = self._api_key or None
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
env_ctx = (
temporary_env_and_lock(self._custom_config)
if self._custom_config
else nullcontext()
)
with env_ctx:
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
)
return response
except Exception as e:
# for break pointing
@@ -475,22 +487,53 @@ class LitellmLLM(LLM):
client = HTTPHandler(timeout=timeout_override or self._timeout)
try:
response = cast(
LiteLLMModelResponse,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
if self._custom_config:
# When custom_config is set, env vars are temporarily injected
# under a global lock. Using stream=True here means the lock is
# only held during connection setup (not the full inference).
# The chunks are then collected outside the lock and reassembled
# into a single ModelResponse via stream_chunk_builder.
from litellm import stream_chunk_builder
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
stream_response = cast(
LiteLLMCustomStreamWrapper,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
chunks = list(stream_response)
response = cast(
LiteLLMModelResponse,
stream_chunk_builder(chunks),
)
else:
response = cast(
LiteLLMModelResponse,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
model_response = from_litellm_model_response(response)
@@ -581,3 +624,29 @@ class LitellmLLM(LLM):
finally:
if client is not None:
client.close()
@contextmanager
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
    """
    Apply `env_variables` to `os.environ` for the duration of the `with` body.

    The module-level `_env_lock` is held from before the variables are written
    until after they have been rolled back, so only one caller at a time runs
    with its overrides in place. Every key is put back to the value it had on
    entry (or removed if it was not set on entry), even when the body raises.
    """
    with _env_lock:
        logger.debug("Acquired lock in temporary_env_and_lock")

        # Snapshot the current value of each key about to be overwritten;
        # None marks a key that was not present on entry.
        saved: dict[str, str | None] = {}
        for name in env_variables:
            saved[name] = os.environ.get(name)

        try:
            os.environ.update(env_variables)
            yield
        finally:
            for name, previous in saved.items():
                if previous is not None:
                    os.environ[name] = previous  # Put the entry value back
                else:
                    os.environ.pop(name, None)  # Key was absent on entry

    logger.debug("Released lock in temporary_env_and_lock")

View File

@@ -1,3 +1,7 @@
import os
import threading
import time
from typing import Any
from unittest.mock import ANY
from unittest.mock import patch
@@ -828,3 +832,274 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
mock_completion.assert_called_once()
kwargs = mock_completion.call_args.kwargs
assert isinstance(kwargs["client"], HTTPHandler)
def test_temporary_env_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify custom_config is visible in os.environ during the litellm call only.

    The mocked `litellm.completion` asserts that every CUSTOM_CONFIG entry is
    present in the environment while it runs; after `invoke()` returns, the
    environment must hold the pre-existing values again and the key that only
    CUSTOM_CONFIG introduced must be gone.
    """
    # Values the environment holds before the call and must hold again after it
    EXPECTED_ENV_VARS = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }

    # Overrides two of the existing keys and introduces one new key
    CUSTOM_CONFIG = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }

    for env_var, value in EXPECTED_ENV_VARS.items():
        monkeypatch.setenv(env_var, value)

    # The final assertion requires this key to be absent, so make sure the
    # ambient environment of the test runner cannot pre-define it.
    monkeypatch.delenv("THIS_IS_RANDOM", raising=False)

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=CUSTOM_CONFIG,
    )

    # When custom_config is set, invoke() internally uses stream=True and
    # reassembles via stream_chunk_builder, so the mock must return stream chunks.
    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hello"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model_name,
        ),
    ]

    def on_litellm_completion(
        **kwargs: Any,  # noqa: ARG001
    ) -> list[litellm.ModelResponse]:
        # Validate that the environment variables are those in custom config
        for env_var, value in CUSTOM_CONFIG.items():
            assert env_var in os.environ
            assert os.environ[env_var] == value
        return mock_stream_chunks

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = on_litellm_completion

        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")

        llm.invoke(messages, user_identity=identity)

        mock_completion.assert_called_once()
        kwargs = mock_completion.call_args.kwargs
        assert kwargs["stream"] is True
        assert "user" not in kwargs
        assert kwargs["metadata"]["foo"] == "bar"

    # Check that the environment variables are back to the original values
    for env_var, value in EXPECTED_ENV_VARS.items():
        assert env_var in os.environ
        assert os.environ[env_var] == value

    # Check that temporary env var from CUSTOM_CONFIG is no longer set
    assert "THIS_IS_RANDOM" not in os.environ
def test_temporary_env_cleanup_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify env vars are restored even when an exception occurs during LLM invocation.

    The mocked `litellm.completion` first checks that CUSTOM_CONFIG is visible
    in os.environ and then raises; the test expects that error to surface from
    `invoke()` and the environment to be back in its pre-call state afterwards.
    """
    # Values the environment holds before the call and must hold again after it
    EXPECTED_ENV_VARS = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }

    # Overrides two of the existing keys and introduces one new key
    CUSTOM_CONFIG = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }

    for env_var, value in EXPECTED_ENV_VARS.items():
        monkeypatch.setenv(env_var, value)

    # The final assertion requires this key to be absent, so make sure the
    # ambient environment of the test runner cannot pre-define it.
    monkeypatch.delenv("THIS_IS_RANDOM", raising=False)

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=CUSTOM_CONFIG,
    )

    def on_litellm_completion_raises(**kwargs: Any) -> None:  # noqa: ARG001
        # Validate that the environment variables are those in custom config
        for env_var, value in CUSTOM_CONFIG.items():
            assert env_var in os.environ
            assert os.environ[env_var] == value
        # Simulate an error during LLM call
        raise RuntimeError("Simulated LLM API failure")

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = on_litellm_completion_raises

        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")

        with pytest.raises(RuntimeError, match="Simulated LLM API failure"):
            llm.invoke(messages, user_identity=identity)

        mock_completion.assert_called_once()

    # Check that the environment variables are back to the original values
    for env_var, value in EXPECTED_ENV_VARS.items():
        assert env_var in os.environ
        assert os.environ[env_var] == value

    # Check that temporary env var from CUSTOM_CONFIG is no longer set
    assert "THIS_IS_RANDOM" not in os.environ
def test_multithreaded_custom_config_isolation(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify the env lock prevents concurrent LLM calls from seeing each other's custom_config.

    Two LitellmLLM instances with different custom_config dicts invoke concurrently.
    The _env_lock in temporary_env_and_lock serializes their access so each call only
    ever sees its own env vars, never the other's.
    """
    # Ensure these keys start unset
    monkeypatch.delenv("SHARED_KEY", raising=False)
    monkeypatch.delenv("LLM_A_ONLY", raising=False)
    monkeypatch.delenv("LLM_B_ONLY", raising=False)

    CONFIG_A = {
        "SHARED_KEY": "value_from_A",
        "LLM_A_ONLY": "a_secret",
    }
    CONFIG_B = {
        "SHARED_KEY": "value_from_B",
        "LLM_B_ONLY": "b_secret",
    }

    # Every key either config touches, de-duplicated
    all_env_keys = sorted(CONFIG_A.keys() | CONFIG_B.keys())

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm_a = LitellmLLM(
        api_key="key_a",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        custom_config=CONFIG_A,
    )
    llm_b = LitellmLLM(
        api_key="key_b",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        custom_config=CONFIG_B,
    )

    # invoke() uses stream=True internally when custom_config is set,
    # so the mock must return stream chunks.
    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hi"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model_name,
        ),
    ]

    # Track what each call observed inside litellm.completion.
    # Keyed by api_key so we can identify which LLM instance made the call.
    observed_envs: dict[str, dict[str, str | None]] = {}

    def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
        time.sleep(0.1)  # We expect someone to get caught on the lock
        api_key = kwargs.get("api_key", "")
        label = "A" if api_key == "key_a" else "B"
        observed_envs[label] = {key: os.environ.get(key) for key in all_env_keys}
        return mock_stream_chunks

    errors: list[Exception] = []

    def run_llm(llm: LitellmLLM) -> None:
        try:
            messages: LanguageModelInput = [UserMessage(content="Hi")]
            llm.invoke(messages)
        except Exception as e:
            # Collected here because an exception raised in a worker thread
            # would not otherwise fail the test.
            errors.append(e)

    with patch("litellm.completion", side_effect=fake_completion):
        # daemon=True so a worker stuck on the lock cannot keep the test
        # process alive after the failure has been reported.
        workers = [
            threading.Thread(target=run_llm, args=(llm_a,), name="llm_a", daemon=True),
            threading.Thread(target=run_llm, args=(llm_b,), name="llm_b", daemon=True),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=10)

        # join(timeout=...) returns silently on timeout; fail explicitly
        # instead of asserting on results from a call that never finished.
        stuck = [worker.name for worker in workers if worker.is_alive()]
        assert not stuck, f"Threads did not finish within timeout: {stuck}"

    assert not errors, f"Thread errors: {errors}"
    assert "A" in observed_envs and "B" in observed_envs

    # Thread A must have seen its own config for SHARED_KEY, not B's
    assert observed_envs["A"]["SHARED_KEY"] == "value_from_A"
    assert observed_envs["A"]["LLM_A_ONLY"] == "a_secret"
    # A must NOT see B's exclusive key
    assert observed_envs["A"]["LLM_B_ONLY"] is None

    # Thread B must have seen its own config for SHARED_KEY, not A's
    assert observed_envs["B"]["SHARED_KEY"] == "value_from_B"
    assert observed_envs["B"]["LLM_B_ONLY"] == "b_secret"
    # B must NOT see A's exclusive key
    assert observed_envs["B"]["LLM_A_ONLY"] is None

    # After both calls, env should be clean
    assert os.environ.get("SHARED_KEY") is None
    assert os.environ.get("LLM_A_ONLY") is None
    assert os.environ.get("LLM_B_ONLY") is None