mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-19 08:45:47 +00:00
Compare commits
2 Commits
dump-scrip
...
v2.1.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc0e084eef | ||
|
|
460f19d2f0 |
13
.vscode/env_template.txt
vendored
13
.vscode/env_template.txt
vendored
@@ -1,6 +1,6 @@
|
||||
# Copy this file to .env in the .vscode folder
|
||||
# Fill in the <REPLACE THIS> values as needed. It is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI.
|
||||
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
|
||||
# Also check out onyx/backend/scripts/restart_containers.sh for a script to restart the containers which Onyx relies on outside of VSCode/Cursor processes
|
||||
|
||||
# For local dev, often user Authentication is not needed
|
||||
AUTH_TYPE=disabled
|
||||
@@ -37,8 +37,8 @@ OPENAI_API_KEY=<REPLACE THIS>
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
|
||||
# Only needed if using DanswerBot
|
||||
# For Onyx Slack Bot, overrides the UI values so no need to set this up via UI every time
|
||||
# Only needed if using OnyxBot
|
||||
#ONYX_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
|
||||
#ONYX_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
|
||||
|
||||
@@ -75,4 +75,9 @@ SHOW_EXTRA_CONNECTORS=True
|
||||
LANGSMITH_TRACING="true"
|
||||
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
|
||||
LANGSMITH_API_KEY=<REPLACE_THIS>
|
||||
LANGSMITH_PROJECT=<REPLACE_THIS>
|
||||
LANGSMITH_PROJECT=<REPLACE_THIS>
|
||||
|
||||
# Local Confluence OAuth testing
|
||||
# OAUTH_CONFLUENCE_CLOUD_CLIENT_ID=<REPLACE_THIS>
|
||||
# OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET=<REPLACE_THIS>
|
||||
# NEXT_PUBLIC_TEST_ENV=True
|
||||
@@ -139,19 +139,13 @@ def get_all_space_permissions(
|
||||
) -> dict[str, ExternalAccess]:
|
||||
logger.debug("Getting space permissions")
|
||||
# Gets all the spaces in the Confluence instance
|
||||
all_space_keys = []
|
||||
start = 0
|
||||
while True:
|
||||
spaces_batch = confluence_client.get_all_spaces(
|
||||
start=start, limit=REQUEST_PAGINATION_LIMIT
|
||||
all_space_keys = [
|
||||
key
|
||||
for space in confluence_client.retrieve_confluence_spaces(
|
||||
limit=REQUEST_PAGINATION_LIMIT,
|
||||
)
|
||||
for space in spaces_batch.get("results", []):
|
||||
all_space_keys.append(space.get("key"))
|
||||
|
||||
if len(spaces_batch.get("results", [])) < REQUEST_PAGINATION_LIMIT:
|
||||
break
|
||||
|
||||
start += len(spaces_batch.get("results", []))
|
||||
if (key := space.get("key"))
|
||||
]
|
||||
|
||||
# Gets the permissions for each space
|
||||
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
|
||||
|
||||
@@ -76,6 +76,7 @@ class ConfluenceCloudOAuth:
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"read:space:confluence%20"
|
||||
"readonly:content.attachment:confluence%20"
|
||||
"search:confluence%20"
|
||||
# granular scope
|
||||
|
||||
@@ -109,13 +109,11 @@ from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.saml import get_saml_account
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.secrets import extract_hashed_cookie
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
@@ -1064,17 +1062,7 @@ async def _check_for_saml_and_jwt(
|
||||
user: User | None,
|
||||
async_db_session: AsyncSession,
|
||||
) -> User | None:
|
||||
# Check if the user has a session cookie from SAML
|
||||
if AUTH_TYPE == AuthType.SAML:
|
||||
saved_cookie = extract_hashed_cookie(request)
|
||||
|
||||
if saved_cookie:
|
||||
saml_account = await get_saml_account(
|
||||
cookie=saved_cookie, async_db_session=async_db_session
|
||||
)
|
||||
user = saml_account.user if saml_account else None
|
||||
|
||||
# If user is still None, check for JWT in Authorization header
|
||||
# If user is None, check for JWT in Authorization header
|
||||
if user is None and JWT_PUBLIC_KEY_URL is not None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
|
||||
@@ -21,7 +21,6 @@ GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
|
||||
PUBLIC_DOC_PAT = "PUBLIC"
|
||||
ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
|
||||
# Cookies
|
||||
FASTAPI_USERS_AUTH_COOKIE_NAME = (
|
||||
|
||||
@@ -744,7 +744,10 @@ class ConfluenceConnector(
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
try:
|
||||
spaces = self.low_timeout_confluence_client.get_all_spaces(limit=1)
|
||||
spaces_iter = self.low_timeout_confluence_client.retrieve_confluence_spaces(
|
||||
limit=1,
|
||||
)
|
||||
first_space = next(spaces_iter, None)
|
||||
except HTTPError as e:
|
||||
status_code = e.response.status_code if e.response else None
|
||||
if status_code == 401:
|
||||
@@ -763,6 +766,12 @@ class ConfluenceConnector(
|
||||
f"Unexpected error while validating Confluence settings: {e}"
|
||||
)
|
||||
|
||||
if not first_space:
|
||||
raise ConnectorValidationError(
|
||||
"No Confluence spaces found. Either your credentials lack permissions, or "
|
||||
"there truly are no spaces in this Confluence instance."
|
||||
)
|
||||
|
||||
if self.space:
|
||||
try:
|
||||
self.low_timeout_confluence_client.get_space(self.space)
|
||||
@@ -771,12 +780,6 @@ class ConfluenceConnector(
|
||||
"Invalid Confluence space key provided"
|
||||
) from e
|
||||
|
||||
if not spaces or not spaces.get("results"):
|
||||
raise ConnectorValidationError(
|
||||
"No Confluence spaces found. Either your credentials lack permissions, or "
|
||||
"there truly are no spaces in this Confluence instance."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -46,7 +46,6 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -63,6 +62,9 @@ _USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
|
||||
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
|
||||
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
@@ -213,6 +215,97 @@ class OnyxConfluence:
|
||||
]
|
||||
return oauth2_dict
|
||||
|
||||
def _build_spaces_url(
|
||||
self,
|
||||
is_v2: bool,
|
||||
base_url: str,
|
||||
limit: int,
|
||||
space_keys: list[str] | None,
|
||||
start: int | None = None,
|
||||
) -> str:
|
||||
"""Build URL for Confluence spaces API with query parameters."""
|
||||
key_param = "keys" if is_v2 else "spaceKey"
|
||||
|
||||
params = [f"limit={limit}"]
|
||||
if space_keys:
|
||||
params.append(f"{key_param}={','.join(space_keys)}")
|
||||
if start is not None and not is_v2:
|
||||
params.append(f"start={start}")
|
||||
|
||||
return f"{base_url}?{'&'.join(params)}"
|
||||
|
||||
def _paginate_spaces_for_endpoint(
|
||||
self,
|
||||
is_v2: bool,
|
||||
base_url: str,
|
||||
limit: int,
|
||||
space_keys: list[str] | None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""Internal helper to paginate through spaces for a specific API endpoint."""
|
||||
start = 0
|
||||
url = self._build_spaces_url(
|
||||
is_v2, base_url, limit, space_keys, start if not is_v2 else None
|
||||
)
|
||||
|
||||
while url:
|
||||
response = self.get(url, advanced_mode=True)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = data.get("results", [])
|
||||
if not results:
|
||||
return
|
||||
|
||||
yield from results
|
||||
|
||||
if is_v2:
|
||||
url = data.get("_links", {}).get("next", "")
|
||||
else:
|
||||
if len(results) < limit:
|
||||
return
|
||||
start += len(results)
|
||||
url = self._build_spaces_url(is_v2, base_url, limit, space_keys, start)
|
||||
|
||||
def retrieve_confluence_spaces(
    self,
    space_keys: list[str] | None = None,
    limit: int = 50,
) -> Iterator[dict[str, str]]:
    """
    Retrieve spaces from Confluence using v2 API (Cloud) or v1 API (Server/fallback).

    Args:
        space_keys: Optional list of space keys to filter by
        limit: Results per page (default 50)

    Yields:
        Space dictionaries with keys: id, key, name, type, status, etc.

    Raises:
        HTTPError: Any HTTP error from the selected endpoint, except a 404
            from the v2 endpoint, which triggers the v1 fallback instead.

    Note:
        The v2 API is used only when the client is a Cloud instance and is not
        using a scoped token. If v2 returns 404, this automatically falls back
        to the v1 API for compatibility with older instances. Spaces already
        yielded from v2 before the 404 are not yielded again by the fallback.
    """
    # Determine API version once
    use_v2 = self._is_cloud and not self.scoped_token
    base_url = _CONFLUENCE_SPACES_API_V2 if use_v2 else _CONFLUENCE_SPACES_API_V1

    # Keys already handed to the caller, so a fallback that restarts from the
    # first v1 page cannot produce duplicates.
    seen_keys: set[str] = set()

    try:
        for space in self._paginate_spaces_for_endpoint(
            use_v2, base_url, limit, space_keys
        ):
            key = space.get("key")
            if key is not None:
                seen_keys.add(key)
            yield space
        return
    except HTTPError as e:
        # An HTTPError is not guaranteed to carry a response object.
        status_code = getattr(e.response, "status_code", None)
        if not (use_v2 and status_code == 404):
            raise
        logger.warning(
            "v2 spaces API returned 404, falling back to v1 API. "
            "This may indicate an older Confluence Cloud instance."
        )

    # Fallback to v1
    for space in self._paginate_spaces_for_endpoint(
        False, _CONFLUENCE_SPACES_API_V1, limit, space_keys
    ):
        if space.get("key") in seen_keys:
            continue
        yield space
|
||||
|
||||
def _probe_connection(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
@@ -226,11 +319,9 @@ class OnyxConfluence:
|
||||
if self.scoped_token:
|
||||
# v2 endpoint doesn't always work with scoped tokens, use v1
|
||||
token = credentials["confluence_access_token"]
|
||||
probe_url = f"{self.base_url}/rest/api/space?limit=1"
|
||||
probe_url = f"{self.base_url}/{_CONFLUENCE_SPACES_API_V1}?limit=1"
|
||||
import requests
|
||||
|
||||
logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")
|
||||
|
||||
try:
|
||||
r = requests.get(
|
||||
probe_url,
|
||||
@@ -252,59 +343,23 @@ class OnyxConfluence:
|
||||
raise e
|
||||
return
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
logger.info("Probing Confluence with OAuth Access Token.")
|
||||
# Initialize connection with probe timeout settings
|
||||
self._confluence = self._initialize_connection_helper(
|
||||
credentials, **merged_kwargs
|
||||
)
|
||||
|
||||
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
|
||||
credentials
|
||||
)
|
||||
url = (
|
||||
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
)
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url, oauth2=oauth2_dict, **merged_kwargs
|
||||
)
|
||||
else:
|
||||
logger.info("Probing Confluence with Personal Access Token.")
|
||||
url = self._url
|
||||
if self._is_cloud:
|
||||
logger.info("running with cloud client")
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
username=credentials["confluence_username"],
|
||||
password=credentials["confluence_access_token"],
|
||||
**merged_kwargs,
|
||||
)
|
||||
else:
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
token=credentials["confluence_access_token"],
|
||||
**merged_kwargs,
|
||||
)
|
||||
# Retrieve first space to validate connection
|
||||
spaces_iter = self.retrieve_confluence_spaces(limit=1)
|
||||
first_space = next(spaces_iter, None)
|
||||
|
||||
# This call sometimes hangs indefinitely, so we run it in a timeout
|
||||
spaces = run_with_timeout(
|
||||
timeout=10,
|
||||
func=confluence_client_with_minimal_retries.get_all_spaces,
|
||||
limit=1,
|
||||
if not first_space:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {self._url}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfortunately, all data is returned in UTC regardless of the user's time zone
|
||||
# even though CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {url}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
)
|
||||
|
||||
logger.info("Confluence probe succeeded.")
|
||||
logger.info("Confluence probe succeeded.")
|
||||
|
||||
def _initialize_connection(
|
||||
self,
|
||||
|
||||
@@ -191,7 +191,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_dynamic(self) -> bool:
|
||||
"""If dynamic, the credentials may change during usage ... maening the client
|
||||
"""If dynamic, the credentials may change during usage ... meaning the client
|
||||
needs to use the locking features of the credentials provider to operate
|
||||
correctly.
|
||||
|
||||
|
||||
@@ -644,6 +644,7 @@ class JiraConnector(
|
||||
jql=self.jql_query,
|
||||
start=0,
|
||||
max_results=1,
|
||||
all_issue_ids=[],
|
||||
)
|
||||
),
|
||||
None,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import contextlib
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
@@ -10,28 +12,23 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi_users import exceptions
|
||||
from fastapi_users.authentication import Strategy
|
||||
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import fastapi_users
|
||||
from onyx.auth.users import get_user_manager
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SAML_CONF_DIR
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.saml import expire_saml_account
|
||||
from onyx.db.saml import get_saml_account
|
||||
from onyx.db.saml import upsert_saml_account
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.secrets import encrypt_string
|
||||
from onyx.utils.secrets import extract_hashed_cookie
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -165,35 +162,63 @@ class SAMLAuthorizeResponse(BaseModel):
|
||||
authorization_url: str
|
||||
|
||||
|
||||
def _sanitize_relay_state(candidate: str | None) -> str | None:
|
||||
"""Ensure the relay state is an internal path to avoid open redirects."""
|
||||
if not candidate:
|
||||
return None
|
||||
|
||||
relay_state = candidate.strip()
|
||||
if not relay_state or not relay_state.startswith("/"):
|
||||
return None
|
||||
|
||||
if "\\" in relay_state:
|
||||
return None
|
||||
|
||||
# Reject colon before query/fragment to match frontend validation
|
||||
path_portion = relay_state.split("?", 1)[0].split("#", 1)[0]
|
||||
if ":" in path_portion:
|
||||
return None
|
||||
|
||||
parsed = urlparse(relay_state)
|
||||
if parsed.scheme or parsed.netloc:
|
||||
return None
|
||||
|
||||
return relay_state
|
||||
|
||||
|
||||
@router.get("/authorize")
|
||||
async def saml_login(request: Request) -> SAMLAuthorizeResponse:
|
||||
req = await prepare_from_fastapi_request(request)
|
||||
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)
|
||||
callback_url = auth.login()
|
||||
return_to = _sanitize_relay_state(request.query_params.get("next"))
|
||||
callback_url = auth.login(return_to=return_to)
|
||||
return SAMLAuthorizeResponse(authorization_url=callback_url)
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def saml_login_callback_get(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
strategy: Strategy[User, uuid.UUID] = Depends(auth_backend.get_strategy),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
) -> Response:
|
||||
"""Handle SAML callback via HTTP-Redirect binding (GET request)"""
|
||||
return await _process_saml_callback(request, db_session)
|
||||
return await _process_saml_callback(request, strategy, user_manager)
|
||||
|
||||
|
||||
@router.post("/callback")
|
||||
async def saml_login_callback(
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
strategy: Strategy[User, uuid.UUID] = Depends(auth_backend.get_strategy),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
) -> Response:
|
||||
"""Handle SAML callback via HTTP-POST binding (POST request)"""
|
||||
return await _process_saml_callback(request, db_session)
|
||||
return await _process_saml_callback(request, strategy, user_manager)
|
||||
|
||||
|
||||
async def _process_saml_callback(
|
||||
request: Request,
|
||||
db_session: Session,
|
||||
strategy: Strategy[User, uuid.UUID],
|
||||
user_manager: UserManager,
|
||||
) -> Response:
|
||||
req = await prepare_from_fastapi_request(request)
|
||||
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)
|
||||
@@ -251,40 +276,19 @@ async def _process_saml_callback(
|
||||
|
||||
user = await upsert_saml_user(email=user_email)
|
||||
|
||||
# Generate a random session cookie and Sha256 encrypt before saving
|
||||
session_cookie = secrets.token_hex(16)
|
||||
saved_cookie = encrypt_string(session_cookie)
|
||||
|
||||
upsert_saml_account(user_id=user.id, cookie=saved_cookie, db_session=db_session)
|
||||
|
||||
# Redirect to main Onyx search page
|
||||
response = Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
response.set_cookie(
|
||||
key="session",
|
||||
value=session_cookie,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
response = await auth_backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def saml_logout(
|
||||
request: Request,
|
||||
async_db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> None:
|
||||
saved_cookie = extract_hashed_cookie(request)
|
||||
|
||||
if saved_cookie:
|
||||
saml_account = await get_saml_account(
|
||||
cookie=saved_cookie, async_db_session=async_db_session
|
||||
user_token: tuple[User, str] = Depends(
|
||||
fastapi_users.authenticator.current_user_token(
|
||||
active=True, verified=REQUIRE_EMAIL_VERIFICATION
|
||||
)
|
||||
if saml_account:
|
||||
await expire_saml_account(
|
||||
saml_account=saml_account, async_db_session=async_db_session
|
||||
)
|
||||
|
||||
return
|
||||
),
|
||||
strategy: Strategy[User, uuid.UUID] = Depends(auth_backend.get_strategy),
|
||||
) -> Response:
|
||||
user, token = user_token
|
||||
return await auth_backend.logout(strategy, user, token)
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
import hashlib
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from onyx.configs.constants import SESSION_KEY
|
||||
|
||||
|
||||
def encrypt_string(s: str) -> str:
    """Return the hexadecimal SHA-256 digest of ``s``.

    Despite the name, this is a one-way hash, not reversible encryption.
    """
    hasher = hashlib.sha256()
    hasher.update(s.encode())
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def extract_hashed_cookie(request: Request) -> str | None:
    """Return the hashed value of the request's session cookie, if any.

    Looks up the cookie named by ``SESSION_KEY`` and passes it through
    ``encrypt_string``. Returns None when the cookie is missing or empty.
    """
    raw_cookie = request.cookies.get(SESSION_KEY)
    if not raw_cookie:
        return None
    return encrypt_string(raw_cookie)
|
||||
@@ -53,6 +53,8 @@ def jira_connector_with_jql() -> JiraConnector:
|
||||
"jira_api_token": os.environ["JIRA_API_TOKEN"],
|
||||
}
|
||||
)
|
||||
connector.validate_connector_settings()
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
|
||||
@@ -23,11 +23,11 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
|
||||
|
||||
class TestDBBase(DeclarativeBase):
|
||||
class DBBaseTest(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class FileRecord(TestDBBase):
|
||||
class FileRecord(DBBaseTest):
|
||||
__tablename__: str = "file_record"
|
||||
|
||||
# Internal file ID, must be unique across all files
|
||||
@@ -56,7 +56,7 @@ class FileRecord(TestDBBase):
|
||||
def db_session() -> Generator[Session, None, None]:
|
||||
"""Create an in-memory SQLite database for testing"""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
TestDBBase.metadata.create_all(engine)
|
||||
DBBaseTest.metadata.create_all(engine)
|
||||
SessionLocal = sessionmaker(bind=engine)
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -67,6 +66,7 @@ def confluence_connector(
|
||||
)
|
||||
# Initialize the client directly
|
||||
connector._confluence_client = mock_confluence_client
|
||||
connector._low_timeout_confluence_client = mock_confluence_client
|
||||
with patch("onyx.connectors.confluence.connector._SLIM_DOC_BATCH_SIZE", 2):
|
||||
yield connector
|
||||
|
||||
@@ -355,27 +355,32 @@ def test_validate_connector_settings_errors(
|
||||
"""Test validation with various error scenarios"""
|
||||
error = HTTPError(response=MagicMock(status_code=status_code))
|
||||
|
||||
confluence_client = MagicMock()
|
||||
confluence_connector._low_timeout_confluence_client = confluence_client
|
||||
get_all_spaces_mock = cast(MagicMock, confluence_client.get_all_spaces)
|
||||
get_all_spaces_mock.side_effect = error
|
||||
with patch(
|
||||
"onyx.connectors.confluence.onyx_confluence.OnyxConfluence.retrieve_confluence_spaces"
|
||||
) as mock_retrieve:
|
||||
mock_retrieve.side_effect = error
|
||||
|
||||
with pytest.raises(expected_exception) as excinfo:
|
||||
confluence_connector.validate_connector_settings()
|
||||
assert expected_message in str(excinfo.value)
|
||||
with pytest.raises(expected_exception) as excinfo:
|
||||
confluence_connector.validate_connector_settings()
|
||||
assert expected_message in str(excinfo.value)
|
||||
|
||||
|
||||
def test_validate_connector_settings_success(
|
||||
confluence_connector: ConfluenceConnector,
|
||||
) -> None:
|
||||
"""Test successful validation"""
|
||||
confluence_client = MagicMock()
|
||||
confluence_connector._low_timeout_confluence_client = confluence_client
|
||||
get_all_spaces_mock = cast(MagicMock, confluence_client.get_all_spaces)
|
||||
get_all_spaces_mock.return_value = {"results": [{"key": "TEST"}]}
|
||||
|
||||
confluence_connector.validate_connector_settings()
|
||||
get_all_spaces_mock.assert_called_once()
|
||||
low_client = confluence_connector.low_timeout_confluence_client
|
||||
with patch.object(
|
||||
low_client, "retrieve_confluence_spaces", return_value=iter([{"key": "TEST"}])
|
||||
) as mock_retrieve, patch.object(
|
||||
low_client,
|
||||
"get_space",
|
||||
return_value={"key": "TEST"},
|
||||
create=True,
|
||||
) as mock_get_space:
|
||||
confluence_connector.validate_connector_settings()
|
||||
mock_retrieve.assert_called_once()
|
||||
mock_get_space.assert_called_once_with(confluence_connector.space)
|
||||
|
||||
|
||||
def test_checkpoint_progress(
|
||||
|
||||
@@ -39,7 +39,6 @@ class _FakeResponse:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_zendesk_client_per_minute_rate_limiting(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { validateInternalRedirect } from "@/lib/auth/redirectValidation";
|
||||
import { getDomain } from "@/lib/redirectSS";
|
||||
import { buildUrl } from "@/lib/utilsSS";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
@@ -28,16 +29,21 @@ async function handleSamlCallback(
|
||||
},
|
||||
};
|
||||
|
||||
let relayState: string | null = null;
|
||||
|
||||
// For POST requests, include form data
|
||||
if (method === "POST") {
|
||||
fetchOptions.body = await request.formData();
|
||||
const formData = await request.formData();
|
||||
const relayStateValue = formData.get("RelayState");
|
||||
relayState = typeof relayStateValue === "string" ? relayStateValue : null;
|
||||
fetchOptions.body = formData;
|
||||
}
|
||||
|
||||
// OneLogin python toolkit only supports HTTP-POST binding for SAMLResponse.
|
||||
// If the IdP returned SAMLResponse via query parameters (GET), convert to POST.
|
||||
if (method === "GET") {
|
||||
const samlResponse = request.nextUrl.searchParams.get("SAMLResponse");
|
||||
const relayState = request.nextUrl.searchParams.get("RelayState");
|
||||
relayState = request.nextUrl.searchParams.get("RelayState");
|
||||
if (samlResponse) {
|
||||
const formData = new FormData();
|
||||
formData.set("SAMLResponse", samlResponse);
|
||||
@@ -61,8 +67,11 @@ async function handleSamlCallback(
|
||||
);
|
||||
}
|
||||
|
||||
const validatedRelayState = validateInternalRedirect(relayState);
|
||||
const redirectDestination = validatedRelayState ?? "/";
|
||||
|
||||
const redirectResponse = NextResponse.redirect(
|
||||
new URL("/", getDomain(request)),
|
||||
new URL(redirectDestination, getDomain(request)),
|
||||
SEE_OTHER_REDIRECT_STATUS
|
||||
);
|
||||
redirectResponse.headers.set("set-cookie", setCookieHeader);
|
||||
|
||||
Reference in New Issue
Block a user