1
0
forked from github/onyx

feat: okta profile tool (#5184)

* Initial Okta profile tool

* Improve

* Fix

* Improve

* Improve

* Address EL comments
This commit is contained in:
Chris Weaver
2025-08-13 09:57:31 -07:00
committed by GitHub
parent fe194076c2
commit e0a9a6fb66
7 changed files with 328 additions and 17 deletions

View File

@@ -121,6 +121,8 @@ OAUTH_CLIENT_SECRET = (
os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET"))
or ""
)
# OpenID Connect configuration URL for Okta Profile Tool and other OIDC integrations
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL") or ""
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
@@ -617,6 +619,17 @@ AVERAGE_SUMMARY_EMBEDDINGS = (
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
#####
# Tool Configs
#####
OKTA_PROFILE_TOOL_ENABLED = (
os.environ.get("OKTA_PROFILE_TOOL_ENABLED", "").lower() == "true"
)
# API token for SSWS auth to Okta Admin API. If set, Users API will be used to enrich profile.
OKTA_API_TOKEN = os.environ.get("OKTA_API_TOKEN") or ""
#####
# Miscellaneous
#####

View File

@@ -72,6 +72,7 @@ class PreviousMessage(BaseModel):
message_type = MessageType.USER
elif isinstance(msg, AIMessage):
message_type = MessageType.ASSISTANT
message = message_to_string(msg)
return cls(
message=message,

View File

@@ -6,6 +6,7 @@ from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.configs.app_configs import OKTA_PROFILE_TOOL_ENABLED
from onyx.db.models import Persona
from onyx.db.models import Tool as ToolDBModel
from onyx.tools.tool_implementations.images.image_generation_tool import (
@@ -17,6 +18,9 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.internet_search.providers import (
get_available_providers,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
@@ -63,6 +67,19 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
if (bool(get_available_providers()))
else []
),
# Show Okta Profile tool if the environment variables are set
*(
[
InCodeToolInfo(
cls=OktaProfileTool,
description="The Okta Profile Action allows the assistant to fetch user information from Okta.",
in_code_tool_id=OktaProfileTool.__name__,
display_name=OktaProfileTool._DISPLAY_NAME,
)
]
if OKTA_PROFILE_TOOL_ENABLED
else []
),
]

View File

@@ -5,12 +5,12 @@ from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.utils.special_types import JSON_ro
if TYPE_CHECKING:
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
@@ -53,8 +53,8 @@ class Tool(abc.ABC, Generic[OVERRIDE_T]):
def get_args_for_non_tool_calling_llm(
self,
query: str,
history: list[PreviousMessage],
llm: LLM,
history: list["PreviousMessage"],
llm: "LLM",
force_run: bool = False,
) -> dict[str, Any] | None:
raise NotImplementedError

View File

@@ -14,6 +14,10 @@ from onyx.configs.app_configs import AZURE_DALLE_API_KEY
from onyx.configs.app_configs import AZURE_DALLE_API_VERSION
from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import OKTA_API_TOKEN
from onyx.configs.app_configs import OPENID_CONFIG_URL
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_CHUNKS
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_RESULTS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
@@ -41,6 +45,9 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import compute_all_tool_tokens
from onyx.tools.utils import explicit_tool_calling_supported
@@ -265,6 +272,33 @@ def construct_tools(
"Internet search tool requires a Bing or Exa API key, please contact your Onyx admin to get it added!"
)
# Handle Okta Profile Tool
elif tool_cls.__name__ == OktaProfileTool.__name__:
if not user_oauth_token:
raise ValueError(
"Okta Profile Tool requires user OAuth token but none found"
)
if not all([OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL]):
raise ValueError(
"Okta Profile Tool requires OAuth configuration to be set"
)
if not OKTA_API_TOKEN:
raise ValueError(
"Okta Profile Tool requires OKTA_API_TOKEN to be set"
)
tool_dict[db_tool_model.id] = [
OktaProfileTool(
access_token=user_oauth_token,
client_id=OAUTH_CLIENT_ID,
client_secret=OAUTH_CLIENT_SECRET,
openid_config_url=OPENID_CONFIG_URL,
okta_api_token=OKTA_API_TOKEN,
)
]
# Handle custom tools
elif db_tool_model.openapi_schema:
if not custom_tool_config:

View File

@@ -0,0 +1,243 @@
import json
from collections.abc import Generator
from typing import Any
from urllib.parse import urlparse
import requests
from pydantic import BaseModel
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import message_to_string
from onyx.prompts.constants import GENERAL_SEP_PAT
from onyx.tools.base_tool import BaseTool
from onyx.tools.models import ToolResponse
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
logger = setup_logger()
OKTA_PROFILE_RESPONSE_ID = "okta_profile"
OKTA_TOOL_DESCRIPTION = """
The Okta profile tool can retrieve user profile information from Okta including:
- User ID, status, creation date
- Profile details like name, email, department, location, title, manager, and more
- Account status and activity
"""
class OIDCConfig(BaseModel):
issuer: str
jwks_uri: str | None = None
userinfo_endpoint: str | None = None
introspection_endpoint: str | None = None
token_endpoint: str | None = None
class OktaProfileTool(BaseTool):
_NAME = "get_okta_profile"
_DESCRIPTION = "This tool is used to get the user's profile information."
_DISPLAY_NAME = "Okta Profile"
def __init__(
self,
access_token: str,
client_id: str,
client_secret: str,
openid_config_url: str,
okta_api_token: str,
request_timeout_sec: int = 15,
) -> None:
self.access_token = access_token
self.client_id = client_id
self.client_secret = client_secret
self.openid_config_url = openid_config_url
self.request_timeout_sec = request_timeout_sec
# Extract Okta org URL from OpenID config URL using URL parsing
# OpenID config URL format: https://{org}.okta.com/.well-known/openid_configuration
parsed_url = urlparse(self.openid_config_url)
self.okta_org_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
self.okta_api_token = okta_api_token
self._oidc_config: OIDCConfig | None = None
@property
def name(self) -> str:
return self._NAME
@property
def description(self) -> str:
return self._DESCRIPTION
@property
def display_name(self) -> str:
return self._DISPLAY_NAME
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {"type": "object", "properties": {}, "required": []},
},
}
def _load_oidc_config(self) -> OIDCConfig:
if self._oidc_config is not None:
return self._oidc_config
resp = requests.get(self.openid_config_url, timeout=self.request_timeout_sec)
resp.raise_for_status()
data = resp.json()
self._oidc_config = OIDCConfig(**data)
logger.debug(f"Loaded OIDC config from {self.openid_config_url}")
return self._oidc_config
def _call_userinfo(self, access_token: str) -> dict[str, Any] | None:
try:
cfg = self._load_oidc_config()
if not cfg.userinfo_endpoint:
logger.info("OIDC config missing userinfo_endpoint")
return None
headers = {"Authorization": f"Bearer {access_token}"}
r = requests.get(
cfg.userinfo_endpoint, headers=headers, timeout=self.request_timeout_sec
)
if r.status_code == 200:
return r.json()
logger.info(
f"userinfo call returned status {r.status_code}: {r.text[:200]}"
)
return None
except requests.RequestException as e:
logger.debug(f"userinfo request failed: {e}")
return None
def _call_introspection(self, access_token: str) -> dict[str, Any] | None:
try:
cfg = self._load_oidc_config()
if not cfg.introspection_endpoint:
logger.info("OIDC config missing introspection_endpoint")
return None
data = {
"token": access_token,
"token_type_hint": "access_token",
}
auth: tuple[str, str] | None = (self.client_id, self.client_secret)
r = requests.post(
cfg.introspection_endpoint,
data=data,
auth=auth,
headers={"Accept": "application/json"},
timeout=self.request_timeout_sec,
)
if r.status_code == 200:
return r.json()
logger.info(
f"introspection call returned status {r.status_code}: {r.text[:200]}"
)
return None
except requests.RequestException as e:
logger.debug(f"introspection request failed: {e}")
return None
def _call_users_api(self, uid: str) -> dict[str, Any]:
"""Call Okta Users API to fetch full user profile.
Requires okta_org_url and okta_api_token to be set. Raises exception on any error.
"""
if not self.okta_org_url or not self.okta_api_token:
raise ValueError(
"Okta org URL and API token are required for user profile lookup"
)
try:
url = f"{self.okta_org_url.rstrip('/')}/api/v1/users/{uid}"
headers = {"Authorization": f"SSWS {self.okta_api_token}"}
r = requests.get(url, headers=headers, timeout=self.request_timeout_sec)
if r.status_code == 200:
return r.json()
raise ValueError(
f"Okta Users API call failed with status {r.status_code}: {r.text[:200]}"
)
except requests.RequestException as e:
raise ValueError(f"Okta Users API request failed: {e}") from e
def build_tool_message_content(
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
# The tool emits a single aggregated packet; pass it through as compact JSON
profile = args[-1].response if args else {}
return json.dumps(profile)
def get_args_for_non_tool_calling_llm(
self,
query: str,
history: list[PreviousMessage],
llm: LLM,
force_run: bool = False,
) -> dict[str, Any] | None:
if force_run:
return {}
# Use LLM to determine if this tool should be called based on the query
prompt = f"""
You are helping to determine if an Okta profile lookup tool should be called based on a user's query.
{OKTA_TOOL_DESCRIPTION}
Query: {query}
Conversation history:
{GENERAL_SEP_PAT}
{history}
{GENERAL_SEP_PAT}
Should the Okta profile tool be called for this query? Respond with only "YES" or "NO".
""".strip()
response = llm.invoke(prompt)
if response and "YES" in message_to_string(response).upper():
return {}
return None
def run(
self, override_kwargs: None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
# Try to get UID from userinfo first, then fallback to introspection
uid_candidate = None
# Try userinfo endpoint first
userinfo_data = self._call_userinfo(self.access_token)
if userinfo_data and isinstance(userinfo_data, dict):
uid_candidate = userinfo_data.get("uid")
# Only try introspection if userinfo didn't provide a UID
if not uid_candidate:
introspection_data = self._call_introspection(self.access_token)
if introspection_data and isinstance(introspection_data, dict):
uid_candidate = introspection_data.get("uid")
if not uid_candidate:
raise ValueError(
"Unable to fetch user profile from Okta. This likely means your Okta "
"token has expired. Please logout, log back in, and try again."
)
# Call Users API to get full profile - this is now required
users_api_data = self._call_users_api(uid_candidate)
yield ToolResponse(
id=OKTA_PROFILE_RESPONSE_ID, response=users_api_data["profile"]
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
# Return the single aggregated profile packet
if not args:
return {}
return args[-1].response

View File

@@ -554,29 +554,32 @@ export const AIMessage = ({
)}
</>
) : null)}
{userKnowledgeFiles && (
{userKnowledgeFiles.length > 0 && (
<UserKnowledgeFiles
userKnowledgeFiles={userKnowledgeFiles}
/>
)}
{!userKnowledgeFiles &&
{userKnowledgeFiles.length === 0 &&
toolCall &&
!TOOLS_WITH_CUSTOM_HANDLING.includes(
toolCall.tool_name
) && (
<ToolRunDisplay
toolName={
toolCall.tool_result && content
? `Used "${toolCall.tool_name}"`
: `Using "${toolCall.tool_name}"`
}
toolLogo={
<FiTool size={15} className="my-auto mr-1" />
}
isRunning={!toolCall.tool_result || !content}
/>
<div className="mb-1">
<ToolRunDisplay
toolName={
toolCall.tool_result && content
? `Used "${toolCall.tool_name}"`
: `Using "${toolCall.tool_name}"`
}
toolLogo={
<FiTool size={15} className="my-auto mr-1" />
}
isRunning={!toolCall.tool_result || !content}
/>
</div>
)}
{toolCall &&
(!files || files.length == 0) &&
toolCall.tool_name === IMAGE_GENERATION_TOOL_NAME &&