Compare commits


1 commit

Author:  Evan Lohn
SHA1:    bf668b5063
Message: fix: max tokens param (#5174)

  * max tokens param
  * fix unit test
  * fix unit test

Date:    2025-08-11 12:28:06 -07:00
2 changed files with 12 additions and 3 deletions
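
The patch below stops hard-coding litellm's max_tokens argument. DefaultMultiLLM now probes which OpenAI-style parameters the model supports at construction time, remembers whether to send the legacy max_tokens name or the standard max_completion_tokens, and passes the limit under that name only when one is actually set. A standalone sketch of the same pattern follows the two file diffs.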

View File

@@ -24,6 +24,7 @@ from langchain_core.messages import SystemMessageChunk
 from langchain_core.messages.tool import ToolCallChunk
 from langchain_core.messages.tool import ToolMessage
 from langchain_core.prompt_values import PromptValue
+from litellm.utils import get_supported_openai_params
 from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
 from onyx.configs.app_configs import MOCK_LLM_RESPONSE
@@ -52,6 +53,8 @@ litellm.telemetry = False
 _LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
 VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
 VERTEX_LOCATION_KWARG = "vertex_location"
+LEGACY_MAX_TOKENS_KWARG = "max_tokens"
+STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
 
 
 class LLMTimeoutError(Exception):
@@ -313,6 +316,14 @@ class DefaultMultiLLM(LLM):
         self._model_kwargs = model_kwargs
 
+        self._max_token_param = LEGACY_MAX_TOKENS_KWARG
+        try:
+            params = get_supported_openai_params(model_name, model_provider)
+            if STANDARD_MAX_TOKENS_KWARG in (params or []):
+                self._max_token_param = STANDARD_MAX_TOKENS_KWARG
+        except Exception as e:
+            logger.warning(f"Error getting supported openai params: {e}")
+
     def _safe_model_config(self) -> dict:
         dump = self.config.model_dump()
         dump["api_key"] = mask_string(dump.get("api_key", ""))
@@ -393,7 +404,6 @@
                 messages=processed_prompt,
                 tools=tools,
                 tool_choice=tool_choice if tools else None,
-                max_tokens=max_tokens,
                 # streaming choice
                 stream=stream,
                 # model params
@@ -426,6 +436,7 @@
                     if structured_response_format
                     else {}
                 ),
+                **({self._max_token_param: max_tokens} if max_tokens else {}),
                 **self._model_kwargs,
             )
         except Exception as e:
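
Taken on its own, the technique above is: ask litellm which OpenAI-style parameters a model accepts, cache the right token-limit spelling, and spread it into the completion call only when a limit is set. A minimal self-contained sketch follows; the helper names resolve_max_token_param and complete, and the provider/model strings, are illustrative rather than taken from this PR.

from typing import Any

import litellm
from litellm.utils import get_supported_openai_params

LEGACY_MAX_TOKENS_KWARG = "max_tokens"
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"


def resolve_max_token_param(model_name: str, model_provider: str) -> str:
    """Pick the token-limit kwarg the model actually supports.

    Falls back to the legacy name if the lookup fails, mirroring the
    defensive try/except in the PR.
    """
    try:
        params = get_supported_openai_params(model_name, model_provider)
        if STANDARD_MAX_TOKENS_KWARG in (params or []):
            return STANDARD_MAX_TOKENS_KWARG
    except Exception:
        pass  # unknown model: keep the legacy parameter name
    return LEGACY_MAX_TOKENS_KWARG


def complete(model: str, provider: str, prompt: str, max_tokens: int | None) -> Any:
    max_token_param = resolve_max_token_param(model, provider)
    return litellm.completion(
        model=f"{provider}/{model}",
        messages=[{"role": "user", "content": prompt}],
        # Pass the limit only when set, under whichever name the model accepts.
        **({max_token_param: max_tokens} if max_tokens else {}),
    )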

View File

@@ -148,7 +148,6 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
         ],
         tools=tools,
         tool_choice=None,
-        max_tokens=None,
         stream=False,
         temperature=0.0,  # Default value from GEN_AI_TEMPERATURE
         timeout=30,
@@ -294,7 +293,6 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> None:
         ],
         tools=tools,
         tool_choice=None,
-        max_tokens=None,
         stream=True,
         temperature=0.0,  # Default value from GEN_AI_TEMPERATURE
         timeout=30,
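
On the test side, the expected call args simply drop max_tokens=None: with the conditional spread, an unset limit means the kwarg is absent rather than passed as None. A hedged sketch of that assertion style, reusing the complete helper from the sketch above; the patch target and test name are illustrative, not from the real test file.

from unittest.mock import patch


def test_completion_omits_token_limit_when_unset() -> None:
    # Intercept the outgoing litellm call instead of hitting a provider.
    with patch("litellm.completion") as mock_completion:
        complete("gpt-4o", "openai", "hello", max_tokens=None)

    _, kwargs = mock_completion.call_args
    # Neither spelling of the limit should appear when none was requested.
    assert "max_tokens" not in kwargs
    assert "max_completion_tokens" not in kwargs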