mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-04 22:42:41 +00:00
Compare commits
6 Commits
cli/v0.2.1
...
dane/litel
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfe529f214 | ||
|
|
a48b95256f | ||
|
|
2d600036d3 | ||
|
|
d708c6f882 | ||
|
|
d1f6f8369d | ||
|
|
5115b621c8 |
@@ -1,4 +1,8 @@
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -49,6 +53,8 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_env_lock = threading.Lock()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import CustomStreamWrapper
|
||||
from litellm import HTTPHandler
|
||||
@@ -378,23 +384,30 @@ class LitellmLLM(LLM):
|
||||
if "api_key" not in passthrough_kwargs:
|
||||
passthrough_kwargs["api_key"] = self._api_key or None
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
client=client,
|
||||
**optional_kwargs,
|
||||
**passthrough_kwargs,
|
||||
# We only need to set environment variables if custom config is set
|
||||
env_ctx = (
|
||||
temporary_env_and_lock(self._custom_config)
|
||||
if self._custom_config
|
||||
else nullcontext()
|
||||
)
|
||||
with env_ctx:
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
client=client,
|
||||
**optional_kwargs,
|
||||
**passthrough_kwargs,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
# for break pointing
|
||||
@@ -475,13 +488,21 @@ class LitellmLLM(LLM):
|
||||
client = HTTPHandler(timeout=timeout_override or self._timeout)
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
# When custom_config is set, env vars are temporarily injected
|
||||
# under a global lock. Using stream=True here means the lock is
|
||||
# only held during connection setup (not the full inference).
|
||||
# The chunks are then collected outside the lock and reassembled
|
||||
# into a single ModelResponse via stream_chunk_builder.
|
||||
from litellm import stream_chunk_builder
|
||||
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
|
||||
|
||||
stream_response = cast(
|
||||
LiteLLMCustomStreamWrapper,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
stream=True,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
@@ -491,6 +512,11 @@ class LitellmLLM(LLM):
|
||||
client=client,
|
||||
),
|
||||
)
|
||||
chunks = list(stream_response)
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
stream_chunk_builder(chunks),
|
||||
)
|
||||
|
||||
model_response = from_litellm_model_response(response)
|
||||
|
||||
@@ -581,3 +607,29 @@ class LitellmLLM(LLM):
|
||||
finally:
|
||||
if client is not None:
|
||||
client.close()
|
||||
|
||||
|
||||
@contextmanager
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
    """Temporarily apply ``env_variables`` to ``os.environ`` under a global lock.

    The module-level ``_env_lock`` is held for as long as the context is
    active, so only one caller at a time can have its overrides present in the
    process environment. When the context exits, whether normally or because
    the wrapped body raised, every touched key is put back to the value it had
    on entry; keys that were not set on entry are removed again.

    Args:
        env_variables: Environment variable names mapped to the values they
            should have while the context is active.

    Yields:
        None. The body of the ``with`` statement runs while the overrides are
        in place and the lock is held.
    """
    try:
        with _env_lock:
            logger.debug("Acquired lock in temporary_env_and_lock")
            # Snapshot the current values before touching anything (None means
            # "was not set") so the cleanup below can restore the exact state.
            original_values: dict[str, str | None] = {
                key: os.environ.get(key) for key in env_variables
            }
            try:
                os.environ.update(env_variables)
                yield
            finally:
                for key, original_value in original_values.items():
                    if original_value is None:
                        # The key did not exist before, so remove it again.
                        os.environ.pop(key, None)
                    else:
                        os.environ[key] = original_value
    finally:
        # Logged once the lock is no longer held. Doing this in a finally
        # keeps the Acquired/Released debug messages paired even when the
        # wrapped body raises.
        logger.debug("Released lock in temporary_env_and_lock")
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import ANY
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -137,42 +141,44 @@ def default_multi_llm() -> LitellmLLM:
|
||||
def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
|
||||
# Mock the litellm.completion function
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
# Create a mock response with multiple tool calls using litellm objects
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content=None,
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
litellm.ChatCompletionMessageToolCall(
|
||||
id="call_1",
|
||||
function=LiteLLMFunction(
|
||||
name="get_weather",
|
||||
arguments='{"location": "New York"}',
|
||||
# invoke() internally uses stream=True and reassembles via
|
||||
# stream_chunk_builder, so the mock must return stream chunks.
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionDeltaToolCall(
|
||||
id="call_1",
|
||||
function=LiteLLMFunction(
|
||||
name="get_weather",
|
||||
arguments='{"location": "New York"}',
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
),
|
||||
type="function",
|
||||
),
|
||||
litellm.ChatCompletionMessageToolCall(
|
||||
id="call_2",
|
||||
function=LiteLLMFunction(
|
||||
name="get_time", arguments='{"timezone": "EST"}'
|
||||
ChatCompletionDeltaToolCall(
|
||||
id="call_2",
|
||||
function=LiteLLMFunction(
|
||||
name="get_time",
|
||||
arguments='{"timezone": "EST"}',
|
||||
),
|
||||
type="function",
|
||||
index=1,
|
||||
),
|
||||
type="function",
|
||||
),
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
usage=litellm.Usage(
|
||||
prompt_tokens=50, completion_tokens=30, total_tokens=80
|
||||
],
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
# Define input messages
|
||||
messages: LanguageModelInput = [
|
||||
@@ -246,11 +252,12 @@ def test_multiple_tool_calls(default_multi_llm: LitellmLLM) -> None:
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice=None,
|
||||
stream=False,
|
||||
stream=True,
|
||||
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
|
||||
timeout=30,
|
||||
max_tokens=None,
|
||||
client=ANY, # HTTPHandler instance created per-request
|
||||
stream_options={"include_usage": True},
|
||||
parallel_tool_calls=True,
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
allowed_openai_params=["tool_choice"],
|
||||
@@ -507,21 +514,20 @@ def test_openai_chat_omits_reasoning_params() -> None:
|
||||
"onyx.llm.multi_llm.is_true_openai_model", return_value=True
|
||||
) as mock_is_openai,
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-5-chat",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-5-chat",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
@@ -539,21 +545,20 @@ def test_user_identity_metadata_enabled(default_multi_llm: LitellmLLM) -> None:
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", True),
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
@@ -573,21 +578,20 @@ def test_user_identity_user_id_truncated_to_64_chars(
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", True),
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
long_user_id = "u" * 82
|
||||
@@ -607,21 +611,20 @@ def test_user_identity_metadata_disabled_omits_identity(
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
@@ -654,21 +657,20 @@ def test_existing_metadata_pass_through_when_identity_disabled() -> None:
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
|
||||
):
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(
|
||||
content="Hello",
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
@@ -688,18 +690,20 @@ def test_openai_model_invoke_uses_httphandler_client(
|
||||
from litellm import HTTPHandler
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
default_multi_llm.invoke(messages)
|
||||
@@ -737,18 +741,20 @@ def test_anthropic_model_passes_no_client() -> None:
|
||||
)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
model="claude-3-opus-20240229",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="claude-3-opus-20240229",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
@@ -769,18 +775,20 @@ def test_bedrock_model_passes_no_client() -> None:
|
||||
)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
@@ -809,18 +817,20 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
|
||||
)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_response = litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=litellm.Message(content="Hello", role="assistant"),
|
||||
)
|
||||
],
|
||||
model="gpt-4o",
|
||||
)
|
||||
mock_completion.return_value = mock_response
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-4o",
|
||||
),
|
||||
]
|
||||
mock_completion.return_value = mock_stream_chunks
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
@@ -828,3 +838,372 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
|
||||
mock_completion.assert_called_once()
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert isinstance(kwargs["client"], HTTPHandler)
|
||||
|
||||
|
||||
def test_temporary_env_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify custom_config is visible as env vars during the call and undone afterwards."""
    # Baseline environment that must be intact once invoke() has returned
    EXPECTED_ENV_VARS = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }

    # Overrides two baseline keys and introduces one key that is not set
    CUSTOM_CONFIG = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }

    for env_var, value in EXPECTED_ENV_VARS.items():
        monkeypatch.setenv(env_var, value)
    # The final assertion expects this key to be absent, so make sure it does
    # not leak in from the environment the test suite happens to run in.
    monkeypatch.delenv("THIS_IS_RANDOM", raising=False)

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=CUSTOM_CONFIG,
    )

    # invoke() internally uses stream=True and reassembles the result via
    # stream_chunk_builder, so the mock must return stream chunks.
    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hello"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model="gpt-3.5-turbo",
        ),
    ]

    def on_litellm_completion(
        **kwargs: Any,  # noqa: ARG001
    ) -> list[litellm.ModelResponse]:
        # While litellm.completion runs, the environment must match custom config
        for env_var, value in CUSTOM_CONFIG.items():
            assert env_var in os.environ
            assert os.environ[env_var] == value

        return mock_stream_chunks

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = on_litellm_completion

        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")

        llm.invoke(messages, user_identity=identity)

        mock_completion.assert_called_once()
        kwargs = mock_completion.call_args.kwargs
        assert kwargs["stream"] is True
        assert "user" not in kwargs
        assert kwargs["metadata"]["foo"] == "bar"

        # Check that the environment variables are back to the original values
        for env_var, value in EXPECTED_ENV_VARS.items():
            assert env_var in os.environ
            assert os.environ[env_var] == value

        # Check that temporary env var from CUSTOM_CONFIG is no longer set
        assert "THIS_IS_RANDOM" not in os.environ
|
||||
|
||||
|
||||
def test_temporary_env_cleanup_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify env vars are restored even when an exception occurs during LLM invocation."""
    # Baseline environment that must be intact after the failed call
    EXPECTED_ENV_VARS = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }

    # Overrides two baseline keys and introduces one key that is not set
    CUSTOM_CONFIG = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }

    for env_var, value in EXPECTED_ENV_VARS.items():
        monkeypatch.setenv(env_var, value)
    # The final assertion expects this key to be absent, so make sure it does
    # not leak in from the environment the test suite happens to run in.
    monkeypatch.delenv("THIS_IS_RANDOM", raising=False)

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=CUSTOM_CONFIG,
    )

    def on_litellm_completion_raises(**kwargs: Any) -> None:  # noqa: ARG001
        # While litellm.completion runs, the environment must match custom config
        for env_var, value in CUSTOM_CONFIG.items():
            assert env_var in os.environ
            assert os.environ[env_var] == value

        # Simulate an error during LLM call
        raise RuntimeError("Simulated LLM API failure")

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = on_litellm_completion_raises

        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")

        with pytest.raises(RuntimeError, match="Simulated LLM API failure"):
            llm.invoke(messages, user_identity=identity)

        mock_completion.assert_called_once()

        # Check that the environment variables are back to the original values
        for env_var, value in EXPECTED_ENV_VARS.items():
            assert env_var in os.environ
            assert os.environ[env_var] == value

        # Check that temporary env var from CUSTOM_CONFIG is no longer set
        assert "THIS_IS_RANDOM" not in os.environ
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_stream", [False, True], ids=["invoke", "stream"])
def test_multithreaded_custom_config_isolation(
    monkeypatch: pytest.MonkeyPatch,
    use_stream: bool,
) -> None:
    """Verify the env lock prevents concurrent LLM calls from seeing each other's custom_config.

    Two LitellmLLM instances with different custom_config dicts call invoke/stream
    concurrently. The _env_lock in temporary_env_and_lock serializes their access so
    each call only ever sees its own env vars—never the other's.
    """
    # Ensure these keys start unset
    monkeypatch.delenv("SHARED_KEY", raising=False)
    monkeypatch.delenv("LLM_A_ONLY", raising=False)
    monkeypatch.delenv("LLM_B_ONLY", raising=False)

    CONFIG_A = {
        "SHARED_KEY": "value_from_A",
        "LLM_A_ONLY": "a_secret",
    }
    CONFIG_B = {
        "SHARED_KEY": "value_from_B",
        "LLM_B_ONLY": "b_secret",
    }

    # Every env var either config touches; dict key views support set union
    all_env_keys = sorted(CONFIG_A.keys() | CONFIG_B.keys())

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm_a = LitellmLLM(
        api_key="key_a",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        custom_config=CONFIG_A,
    )
    llm_b = LitellmLLM(
        api_key="key_b",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        custom_config=CONFIG_B,
    )

    # Both invoke (with custom_config) and stream use stream=True at the
    # litellm level, so the mock must return stream chunks.
    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hi"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model_name,
        ),
    ]

    # Track what each call observed inside litellm.completion.
    # Keyed by api_key so we can identify which LLM instance made the call.
    observed_envs: dict[str, dict[str, str | None]] = {}

    def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
        time.sleep(0.1)  # We expect someone to get caught on the lock
        api_key = kwargs.get("api_key", "")
        label = "A" if api_key == "key_a" else "B"

        observed_envs[label] = {key: os.environ.get(key) for key in all_env_keys}

        return mock_stream_chunks

    errors: list[Exception] = []

    def run_llm(llm: LitellmLLM) -> None:
        try:
            messages: LanguageModelInput = [UserMessage(content="Hi")]
            if use_stream:
                list(llm.stream(messages))
            else:
                llm.invoke(messages)
        except Exception as e:
            # Collected here and asserted on from the main thread
            errors.append(e)

    with patch("litellm.completion", side_effect=fake_completion):
        # Daemon threads so a worker stuck on the lock cannot keep the test
        # process alive after the failure has been reported.
        t_a = threading.Thread(target=run_llm, args=(llm_a,), daemon=True)
        t_b = threading.Thread(target=run_llm, args=(llm_b,), daemon=True)

        t_a.start()
        t_b.start()
        t_a.join(timeout=10)
        t_b.join(timeout=10)

    # join() returns silently when its timeout expires, so check explicitly
    # that both workers actually finished.
    assert not t_a.is_alive(), "Thread A did not finish within the timeout"
    assert not t_b.is_alive(), "Thread B did not finish within the timeout"

    assert not errors, f"Thread errors: {errors}"
    assert "A" in observed_envs and "B" in observed_envs

    # Thread A must have seen its own config for SHARED_KEY, not B's
    assert observed_envs["A"]["SHARED_KEY"] == "value_from_A"
    assert observed_envs["A"]["LLM_A_ONLY"] == "a_secret"
    # A must NOT see B's exclusive key
    assert observed_envs["A"]["LLM_B_ONLY"] is None

    # Thread B must have seen its own config for SHARED_KEY, not A's
    assert observed_envs["B"]["SHARED_KEY"] == "value_from_B"
    assert observed_envs["B"]["LLM_B_ONLY"] == "b_secret"
    # B must NOT see A's exclusive key
    assert observed_envs["B"]["LLM_A_ONLY"] is None

    # After both calls, env should be clean
    assert os.environ.get("SHARED_KEY") is None
    assert os.environ.get("LLM_A_ONLY") is None
    assert os.environ.get("LLM_B_ONLY") is None
|
||||
|
||||
|
||||
def test_multithreaded_invoke_without_custom_config_skips_env_lock() -> None:
    """Verify that invoke() without custom_config does not acquire the env lock.

    Two LitellmLLM instances without custom_config call invoke concurrently.
    Both should call litellm.completion with stream=True and never enter
    temporary_env_and_lock.
    """
    from onyx.llm import multi_llm as multi_llm_module

    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"

    llm_a = LitellmLLM(
        api_key="key_a",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
    )
    llm_b = LitellmLLM(
        api_key="key_b",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
    )

    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hi"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model_name,
        ),
    ]

    # kwargs each instance passed to litellm.completion, keyed by "A"/"B"
    call_kwargs: dict[str, dict[str, Any]] = {}

    def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
        api_key = kwargs.get("api_key", "")
        label = "A" if api_key == "key_a" else "B"
        call_kwargs[label] = kwargs
        return mock_stream_chunks

    errors: list[Exception] = []

    def run_llm(llm: LitellmLLM) -> None:
        try:
            messages: LanguageModelInput = [UserMessage(content="Hi")]
            llm.invoke(messages)
        except Exception as e:
            # Collected here and asserted on from the main thread
            errors.append(e)

    with (
        patch("litellm.completion", side_effect=fake_completion),
        patch.object(
            multi_llm_module,
            "temporary_env_and_lock",
            wraps=multi_llm_module.temporary_env_and_lock,
        ) as mock_env_lock,
    ):
        # Daemon threads so a stuck worker cannot keep the test process alive
        # after the failure has been reported.
        t_a = threading.Thread(target=run_llm, args=(llm_a,), daemon=True)
        t_b = threading.Thread(target=run_llm, args=(llm_b,), daemon=True)

        t_a.start()
        t_b.start()
        t_a.join(timeout=10)
        t_b.join(timeout=10)

    # join() returns silently when its timeout expires, so check explicitly
    # that both workers actually finished.
    assert not t_a.is_alive(), "Thread A did not finish within the timeout"
    assert not t_b.is_alive(), "Thread B did not finish within the timeout"

    assert not errors, f"Thread errors: {errors}"
    assert "A" in call_kwargs and "B" in call_kwargs

    # invoke() always uses stream=True internally (reassembles via stream_chunk_builder)
    assert call_kwargs["A"]["stream"] is True
    assert call_kwargs["B"]["stream"] is True

    # The env lock context manager should never have been called
    mock_env_lock.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user