Compare commits

...

6 Commits

Author SHA1 Message Date
Dane Urban
bfe529f214 fix tests 2026-02-13 10:00:43 -08:00
Dane Urban
a48b95256f Use stream 2026-02-12 20:46:41 -08:00
Dane Urban
2d600036d3 Add comment 2026-02-11 15:37:12 -08:00
Dane Urban
d708c6f882 nts 2026-02-10 22:08:05 -08:00
Dane Urban
d1f6f8369d Parameterise the tests for stream and invoke 2026-02-10 22:05:18 -08:00
Dane Urban
5115b621c8 n 2026-02-10 21:01:32 -08:00
2 changed files with 607 additions and 176 deletions

View File

@@ -1,4 +1,8 @@
import os
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from contextlib import nullcontext
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -49,6 +53,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_env_lock = threading.Lock()
if TYPE_CHECKING:
from litellm import CustomStreamWrapper
from litellm import HTTPHandler
@@ -378,23 +384,30 @@ class LitellmLLM(LLM):
if "api_key" not in passthrough_kwargs:
passthrough_kwargs["api_key"] = self._api_key or None
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
# We only need to set environment variables if custom config is set
env_ctx = (
temporary_env_and_lock(self._custom_config)
if self._custom_config
else nullcontext()
)
with env_ctx:
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
)
return response
except Exception as e:
# for break pointing
@@ -475,13 +488,21 @@ class LitellmLLM(LLM):
client = HTTPHandler(timeout=timeout_override or self._timeout)
try:
response = cast(
LiteLLMModelResponse,
# When custom_config is set, env vars are temporarily injected
# under a global lock. Using stream=True here means the lock is
# only held during connection setup (not the full inference).
# The chunks are then collected outside the lock and reassembled
# into a single ModelResponse via stream_chunk_builder.
from litellm import stream_chunk_builder
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
stream_response = cast(
LiteLLMCustomStreamWrapper,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
@@ -491,6 +512,11 @@ class LitellmLLM(LLM):
client=client,
),
)
chunks = list(stream_response)
response = cast(
LiteLLMModelResponse,
stream_chunk_builder(chunks),
)
model_response = from_litellm_model_response(response)
@@ -581,3 +607,29 @@ class LitellmLLM(LLM):
finally:
if client is not None:
client.close()
@contextmanager
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
    """Overlay ``env_variables`` onto ``os.environ`` under a global lock.

    The process environment is shared by all threads, so the overlay is
    serialized via the module-level ``_env_lock``. On exit — whether the
    body completed normally or raised — every touched key is restored to
    its prior state before the lock is released.
    """
    with _env_lock:
        logger.debug("Acquired lock in temporary_env_and_lock")
        # Snapshot the prior state of every key we are about to touch;
        # None marks a key that did not exist beforehand.
        saved: dict[str, str | None] = {
            key: os.environ.get(key) for key in env_variables
        }
        try:
            os.environ.update(env_variables)
            yield
        finally:
            for key, previous in saved.items():
                if previous is None:
                    # The key was absent before — remove it again.
                    os.environ.pop(key, None)
                else:
                    # Put the pre-existing value back.
                    os.environ[key] = previous
            logger.debug("Released lock in temporary_env_and_lock")

View File

@@ -1,3 +1,7 @@
import os
import threading
import time
from typing import Any
from unittest.mock import ANY
from unittest.mock import patch
@@ -137,42 +141,44 @@ def default_multi_llm() -> LitellmLLM:
def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
# Mock the litellm.completion function
with patch("litellm.completion") as mock_completion:
# Create a mock response with multiple tool calls using litellm objects
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="tool_calls",
index=0,
message=litellm.Message(
content=None,
role="assistant",
tool_calls=[
litellm.ChatCompletionMessageToolCall(
id="call_1",
function=LiteLLMFunction(
name="get_weather",
arguments='{"location": "New York"}',
# invoke() internally uses stream=True and reassembles via
# stream_chunk_builder, so the mock must return stream chunks.
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(
role="assistant",
tool_calls=[
ChatCompletionDeltaToolCall(
id="call_1",
function=LiteLLMFunction(
name="get_weather",
arguments='{"location": "New York"}',
),
type="function",
index=0,
),
type="function",
),
litellm.ChatCompletionMessageToolCall(
id="call_2",
function=LiteLLMFunction(
name="get_time", arguments='{"timezone": "EST"}'
ChatCompletionDeltaToolCall(
id="call_2",
function=LiteLLMFunction(
name="get_time",
arguments='{"timezone": "EST"}',
),
type="function",
index=1,
),
type="function",
),
],
),
)
],
model="gpt-3.5-turbo",
usage=litellm.Usage(
prompt_tokens=50, completion_tokens=30, total_tokens=80
],
),
finish_reason="tool_calls",
index=0,
)
],
model="gpt-3.5-turbo",
),
)
mock_completion.return_value = mock_response
]
mock_completion.return_value = mock_stream_chunks
# Define input messages
messages: LanguageModelInput = [
@@ -246,11 +252,12 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
],
tools=tools,
tool_choice=None,
stream=False,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
max_tokens=None,
client=ANY, # HTTPHandler instance created per-request
stream_options={"include_usage": True},
parallel_tool_calls=True,
mock_response=MOCK_LLM_RESPONSE,
allowed_openai_params=["tool_choice"],
@@ -507,21 +514,20 @@ def test_openai_chat_omits_reasoning_params() -> None:
"onyx.llm.multi_llm.is_true_openai_model", return_value=True
) as mock_is_openai,
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-5-chat",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-5-chat",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
@@ -539,21 +545,20 @@ def test_user_identity_metadata_enabled(default_multi_llm: LitellmLLM) -> None:
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", True),
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
@@ -573,21 +578,20 @@ def test_user_identity_user_id_truncated_to_64_chars(
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", True),
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
long_user_id = "u" * 82
@@ -607,21 +611,20 @@ def test_user_identity_metadata_disabled_omits_identity(
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
@@ -654,21 +657,20 @@ def test_existing_metadata_pass_through_when_identity_disabled() -> None:
patch("litellm.completion") as mock_completion,
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
):
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(
content="Hello",
role="assistant",
),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
@@ -688,18 +690,20 @@ def test_openai_model_invoke_uses_httphandler_client(
from litellm import HTTPHandler
with patch("litellm.completion") as mock_completion:
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="Hello", role="assistant"),
)
],
model="gpt-3.5-turbo",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-3.5-turbo",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
default_multi_llm.invoke(messages)
@@ -737,18 +741,20 @@ def test_anthropic_model_passes_no_client() -> None:
)
with patch("litellm.completion") as mock_completion:
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="Hello", role="assistant"),
)
],
model="claude-3-opus-20240229",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="claude-3-opus-20240229",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
@@ -769,18 +775,20 @@ def test_bedrock_model_passes_no_client() -> None:
)
with patch("litellm.completion") as mock_completion:
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="Hello", role="assistant"),
)
],
model="anthropic.claude-3-sonnet-20240229-v1:0",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="anthropic.claude-3-sonnet-20240229-v1:0",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
@@ -809,18 +817,20 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
)
with patch("litellm.completion") as mock_completion:
mock_response = litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
finish_reason="stop",
index=0,
message=litellm.Message(content="Hello", role="assistant"),
)
],
model="gpt-4o",
)
mock_completion.return_value = mock_response
mock_stream_chunks = [
litellm.ModelResponse(
id="chatcmpl-123",
choices=[
litellm.Choices(
delta=_create_delta(role="assistant", content="Hello"),
finish_reason="stop",
index=0,
)
],
model="gpt-4o",
),
]
mock_completion.return_value = mock_stream_chunks
messages: LanguageModelInput = [UserMessage(content="Hi")]
llm.invoke(messages)
@@ -828,3 +838,372 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
mock_completion.assert_called_once()
kwargs = mock_completion.call_args.kwargs
assert isinstance(kwargs["client"], HTTPHandler)
def test_temporary_env_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
    """invoke() with custom_config must leave the environment exactly as found."""
    # Pre-existing environment that must survive the call untouched.
    baseline_env = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }
    # Custom config deliberately overlaps two baseline keys and adds one new one.
    custom_config = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }
    for name, value in baseline_env.items():
        monkeypatch.setenv(name, value)

    provider = LlmProviderNames.OPENAI
    model = "gpt-3.5-turbo"
    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=provider,
        model_name=model,
        max_input_tokens=get_max_input_tokens(
            model_provider=provider,
            model_name=model,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=custom_config,
    )

    # When custom_config is set, invoke() internally uses stream=True and
    # reassembles via stream_chunk_builder, so the mock must return stream chunks.
    stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hello"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model="gpt-3.5-turbo",
        ),
    ]

    def fake_completion(
        **kwargs: dict[str, Any],  # noqa: ARG001
    ) -> list[litellm.ModelResponse]:
        # While inside litellm.completion the custom-config values must be live.
        for name, value in custom_config.items():
            assert name in os.environ
            assert os.environ[name] == value
        return stream_chunks

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = fake_completion
        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
        llm.invoke(messages, user_identity=identity)

        mock_completion.assert_called_once()
        kwargs = mock_completion.call_args.kwargs
        assert kwargs["stream"] is True
        assert "user" not in kwargs
        assert kwargs["metadata"]["foo"] == "bar"

    # The baseline environment must be fully restored afterwards...
    for name, value in baseline_env.items():
        assert name in os.environ
        assert os.environ[name] == value
    # ...and the key that only existed in custom_config must be gone again.
    assert "THIS_IS_RANDOM" not in os.environ
def test_temporary_env_cleanup_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
    """Env vars must be restored even when the LLM invocation raises mid-flight."""
    # Pre-existing environment that must survive the (failing) call untouched.
    baseline_env = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }
    custom_config = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }
    for name, value in baseline_env.items():
        monkeypatch.setenv(name, value)

    provider = LlmProviderNames.OPENAI
    model = "gpt-3.5-turbo"
    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=provider,
        model_name=model,
        max_input_tokens=get_max_input_tokens(
            model_provider=provider,
            model_name=model,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=custom_config,
    )

    def raising_completion(**kwargs: dict[str, Any]) -> None:  # noqa: ARG001
        # The custom-config values must be live at the moment of the call.
        for name, value in custom_config.items():
            assert name in os.environ
            assert os.environ[name] == value
        # Simulate an error during LLM call
        raise RuntimeError("Simulated LLM API failure")

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = raising_completion
        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
        with pytest.raises(RuntimeError, match="Simulated LLM API failure"):
            llm.invoke(messages, user_identity=identity)
        mock_completion.assert_called_once()

    # Despite the exception, the baseline environment is fully restored...
    for name, value in baseline_env.items():
        assert name in os.environ
        assert os.environ[name] == value
    # ...and the key that only existed in custom_config is gone again.
    assert "THIS_IS_RANDOM" not in os.environ
@pytest.mark.parametrize("use_stream", [False, True], ids=["invoke", "stream"])
def test_multithreaded_custom_config_isolation(
    monkeypatch: pytest.MonkeyPatch,
    use_stream: bool,
) -> None:
    """Concurrent calls with different custom_configs must not see each other's env.

    Two LitellmLLM instances carrying conflicting custom_config dicts run
    invoke/stream on separate threads. temporary_env_and_lock serializes the
    env mutation via _env_lock, so each litellm.completion call may only ever
    observe its own instance's variables — never the other's.
    """
    # Start from a clean slate for every key either config touches.
    monkeypatch.delenv("SHARED_KEY", raising=False)
    monkeypatch.delenv("LLM_A_ONLY", raising=False)
    monkeypatch.delenv("LLM_B_ONLY", raising=False)

    config_a = {
        "SHARED_KEY": "value_from_A",
        "LLM_A_ONLY": "a_secret",
    }
    config_b = {
        "SHARED_KEY": "value_from_B",
        "LLM_B_ONLY": "b_secret",
    }
    watched_keys = list(set(config_a) | set(config_b))

    provider = LlmProviderNames.OPENAI
    model = "gpt-3.5-turbo"

    def build_llm(api_key: str, config: dict[str, str]) -> LitellmLLM:
        # Helper: the two instances differ only in api_key and custom_config.
        return LitellmLLM(
            api_key=api_key,
            timeout=30,
            model_provider=provider,
            model_name=model,
            max_input_tokens=get_max_input_tokens(
                model_provider=provider,
                model_name=model,
            ),
            custom_config=config,
        )

    llm_a = build_llm("key_a", config_a)
    llm_b = build_llm("key_b", config_b)

    # Both invoke (with custom_config) and stream use stream=True at the
    # litellm level, so the mock must return stream chunks.
    chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hi"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model,
        ),
    ]

    # What each call observed inside litellm.completion, keyed "A"/"B"
    # (identified via the api_key each instance passes through).
    observed_envs: dict[str, dict[str, str | None]] = {}

    def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
        time.sleep(0.1)  # We expect someone to get caught on the lock
        label = "A" if kwargs.get("api_key", "") == "key_a" else "B"
        observed_envs[label] = {key: os.environ.get(key) for key in watched_keys}
        return chunks

    errors: list[Exception] = []

    def drive(llm: LitellmLLM) -> None:
        try:
            messages: LanguageModelInput = [UserMessage(content="Hi")]
            if use_stream:
                list(llm.stream(messages))
            else:
                llm.invoke(messages)
        except Exception as e:
            errors.append(e)

    with patch("litellm.completion", side_effect=fake_completion):
        workers = [
            threading.Thread(target=drive, args=(llm_a,)),
            threading.Thread(target=drive, args=(llm_b,)),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=10)

    assert not errors, f"Thread errors: {errors}"
    assert "A" in observed_envs and "B" in observed_envs

    # Thread A saw exactly its own config — including for the contested key —
    # and nothing of B's exclusive key.
    assert observed_envs["A"]["SHARED_KEY"] == "value_from_A"
    assert observed_envs["A"]["LLM_A_ONLY"] == "a_secret"
    assert observed_envs["A"]["LLM_B_ONLY"] is None
    # Symmetrically for thread B.
    assert observed_envs["B"]["SHARED_KEY"] == "value_from_B"
    assert observed_envs["B"]["LLM_B_ONLY"] == "b_secret"
    assert observed_envs["B"]["LLM_A_ONLY"] is None

    # After both calls complete, every temporary key is cleaned up.
    assert os.environ.get("SHARED_KEY") is None
    assert os.environ.get("LLM_A_ONLY") is None
    assert os.environ.get("LLM_B_ONLY") is None
def test_multithreaded_invoke_without_custom_config_skips_env_lock() -> None:
    """Verify that invoke() without custom_config does not acquire the env lock.

    Two LitellmLLM instances without custom_config call invoke concurrently.
    invoke() still issues the underlying litellm call with stream=True (the
    chunks are reassembled via stream_chunk_builder), but because neither
    instance carries a custom_config, temporary_env_and_lock must never be
    entered, so the two calls complete without serializing on the env lock.
    """
    # NOTE: the original docstring claimed "Both should run with stream=False",
    # which contradicted this test's own `stream is True` assertions; fixed.
    from onyx.llm import multi_llm as multi_llm_module

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"
    llm_a = LitellmLLM(
        api_key="key_a",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
    )
    llm_b = LitellmLLM(
        api_key="key_b",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
    )

    # invoke() consumes a stream internally, so the mock returns stream chunks.
    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hi"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model_name,
        ),
    ]

    # kwargs each call passed to litellm.completion, keyed "A"/"B" by api_key.
    call_kwargs: dict[str, dict[str, Any]] = {}

    def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
        api_key = kwargs.get("api_key", "")
        label = "A" if api_key == "key_a" else "B"
        call_kwargs[label] = kwargs
        return mock_stream_chunks

    errors: list[Exception] = []

    def run_llm(llm: LitellmLLM) -> None:
        try:
            messages: LanguageModelInput = [UserMessage(content="Hi")]
            llm.invoke(messages)
        except Exception as e:
            errors.append(e)

    with (
        patch("litellm.completion", side_effect=fake_completion),
        # Wrap (rather than replace) the context manager so its real behavior
        # is preserved while call counts are recorded.
        patch.object(
            multi_llm_module,
            "temporary_env_and_lock",
            wraps=multi_llm_module.temporary_env_and_lock,
        ) as mock_env_lock,
    ):
        t_a = threading.Thread(target=run_llm, args=(llm_a,))
        t_b = threading.Thread(target=run_llm, args=(llm_b,))
        t_a.start()
        t_b.start()
        t_a.join(timeout=10)
        t_b.join(timeout=10)

    assert not errors, f"Thread errors: {errors}"
    assert "A" in call_kwargs and "B" in call_kwargs
    # invoke() always uses stream=True internally (reassembles via stream_chunk_builder)
    assert call_kwargs["A"]["stream"] is True
    assert call_kwargs["B"]["stream"] is True
    # The env lock context manager should never have been called
    mock_env_lock.assert_not_called()