Compare commits

...

2 Commits

Author SHA1 Message Date
Wenxi
fc0e084eef chore(hotfix): v2.1.2 fix jira and confluence connectors (#5967)
Co-authored-by: Evan Lohn <evan@danswer.ai>
2025-10-29 09:47:46 -07:00
Justin Tahara
460f19d2f0 chore(hotfix): Align Cookie Usage (#5954) (#5965) 2025-10-28 17:05:48 -07:00
16 changed files with 225 additions and 174 deletions

View File

@@ -1,6 +1,6 @@
# Copy this file to .env in the .vscode folder
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
# Also check out danswer/backend/scripts/restart_containers.sh for a script to restart the containers which Danswer relies on outside of VSCode/Cursor processes
# Also check out onyx/backend/scripts/restart_containers.sh for a script to restart the containers which Onyx relies on outside of VSCode/Cursor processes
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
@@ -37,8 +37,8 @@ OPENAI_API_KEY=<REPLACE THIS>
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o
# For Danswer Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using DanswerBot
# For Onyx Slack Bot, overrides the UI values so no need to set this up via UI every time
# Only needed if using OnyxBot
#ONYX_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
#ONYX_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
@@ -75,4 +75,9 @@ SHOW_EXTRA_CONNECTORS=True
LANGSMITH_TRACING="true"
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
LANGSMITH_API_KEY=<REPLACE_THIS>
LANGSMITH_PROJECT=<REPLACE_THIS>
LANGSMITH_PROJECT=<REPLACE_THIS>
# Local Confluence OAuth testing
# OAUTH_CONFLUENCE_CLOUD_CLIENT_ID=<REPLACE_THIS>
# OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET=<REPLACE_THIS>
# NEXT_PUBLIC_TEST_ENV=True

View File

@@ -139,19 +139,13 @@ def get_all_space_permissions(
) -> dict[str, ExternalAccess]:
logger.debug("Getting space permissions")
# Gets all the spaces in the Confluence instance
all_space_keys = []
start = 0
while True:
spaces_batch = confluence_client.get_all_spaces(
start=start, limit=REQUEST_PAGINATION_LIMIT
all_space_keys = [
key
for space in confluence_client.retrieve_confluence_spaces(
limit=REQUEST_PAGINATION_LIMIT,
)
for space in spaces_batch.get("results", []):
all_space_keys.append(space.get("key"))
if len(spaces_batch.get("results", [])) < REQUEST_PAGINATION_LIMIT:
break
start += len(spaces_batch.get("results", []))
if (key := space.get("key"))
]
# Gets the permissions for each space
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")

View File

@@ -76,6 +76,7 @@ class ConfluenceCloudOAuth:
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"read:space:confluence%20"
"readonly:content.attachment:confluence%20"
"search:confluence%20"
# granular scope

View File

@@ -109,13 +109,11 @@ from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.saml import get_saml_account
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.secrets import extract_hashed_cookie
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -1064,17 +1062,7 @@ async def _check_for_saml_and_jwt(
user: User | None,
async_db_session: AsyncSession,
) -> User | None:
# Check if the user has a session cookie from SAML
if AUTH_TYPE == AuthType.SAML:
saved_cookie = extract_hashed_cookie(request)
if saved_cookie:
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
user = saml_account.user if saml_account else None
# If user is still None, check for JWT in Authorization header
# If user is None, check for JWT in Authorization header
if user is None and JWT_PUBLIC_KEY_URL is not None:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):

View File

@@ -21,7 +21,6 @@ GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
PUBLIC_DOC_PAT = "PUBLIC"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
# Cookies
FASTAPI_USERS_AUTH_COOKIE_NAME = (

View File

@@ -744,7 +744,10 @@ class ConfluenceConnector(
def validate_connector_settings(self) -> None:
try:
spaces = self.low_timeout_confluence_client.get_all_spaces(limit=1)
spaces_iter = self.low_timeout_confluence_client.retrieve_confluence_spaces(
limit=1,
)
first_space = next(spaces_iter, None)
except HTTPError as e:
status_code = e.response.status_code if e.response else None
if status_code == 401:
@@ -763,6 +766,12 @@ class ConfluenceConnector(
f"Unexpected error while validating Confluence settings: {e}"
)
if not first_space:
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
"there truly are no spaces in this Confluence instance."
)
if self.space:
try:
self.low_timeout_confluence_client.get_space(self.space)
@@ -771,12 +780,6 @@ class ConfluenceConnector(
"Invalid Confluence space key provided"
) from e
if not spaces or not spaces.get("results"):
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
"there truly are no spaces in this Confluence instance."
)
if __name__ == "__main__":
import os

View File

@@ -46,7 +46,6 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
@@ -63,6 +62,9 @@ _USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
_DEFAULT_PAGINATION_LIMIT = 1000
_CONFLUENCE_SPACES_API_V1 = "rest/api/space"
_CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces"
class ConfluenceRateLimitError(Exception):
pass
@@ -213,6 +215,97 @@ class OnyxConfluence:
]
return oauth2_dict
def _build_spaces_url(
self,
is_v2: bool,
base_url: str,
limit: int,
space_keys: list[str] | None,
start: int | None = None,
) -> str:
"""Build URL for Confluence spaces API with query parameters."""
key_param = "keys" if is_v2 else "spaceKey"
params = [f"limit={limit}"]
if space_keys:
params.append(f"{key_param}={','.join(space_keys)}")
if start is not None and not is_v2:
params.append(f"start={start}")
return f"{base_url}?{'&'.join(params)}"
def _paginate_spaces_for_endpoint(
self,
is_v2: bool,
base_url: str,
limit: int,
space_keys: list[str] | None,
) -> Iterator[dict[str, Any]]:
"""Internal helper to paginate through spaces for a specific API endpoint."""
start = 0
url = self._build_spaces_url(
is_v2, base_url, limit, space_keys, start if not is_v2 else None
)
while url:
response = self.get(url, advanced_mode=True)
response.raise_for_status()
data = response.json()
results = data.get("results", [])
if not results:
return
yield from results
if is_v2:
url = data.get("_links", {}).get("next", "")
else:
if len(results) < limit:
return
start += len(results)
url = self._build_spaces_url(is_v2, base_url, limit, space_keys, start)
def retrieve_confluence_spaces(
self,
space_keys: list[str] | None = None,
limit: int = 50,
) -> Iterator[dict[str, str]]:
"""
Retrieve spaces from Confluence using v2 API (Cloud) or v1 API (Server/fallback).
Args:
space_keys: Optional list of space keys to filter by
limit: Results per page (default 50)
Yields:
Space dictionaries with keys: id, key, name, type, status, etc.
Note:
For Cloud instances, attempts v2 API first. If v2 returns 404,
automatically falls back to v1 API for compatibility with older instances.
"""
# Determine API version once
use_v2 = self._is_cloud and not self.scoped_token
base_url = _CONFLUENCE_SPACES_API_V2 if use_v2 else _CONFLUENCE_SPACES_API_V1
try:
yield from self._paginate_spaces_for_endpoint(
use_v2, base_url, limit, space_keys
)
except HTTPError as e:
if e.response.status_code == 404 and use_v2:
logger.warning(
"v2 spaces API returned 404, falling back to v1 API. "
"This may indicate an older Confluence Cloud instance."
)
# Fallback to v1
yield from self._paginate_spaces_for_endpoint(
False, _CONFLUENCE_SPACES_API_V1, limit, space_keys
)
else:
raise
def _probe_connection(
self,
**kwargs: Any,
@@ -226,11 +319,9 @@ class OnyxConfluence:
if self.scoped_token:
# v2 endpoint doesn't always work with scoped tokens, use v1
token = credentials["confluence_access_token"]
probe_url = f"{self.base_url}/rest/api/space?limit=1"
probe_url = f"{self.base_url}/{_CONFLUENCE_SPACES_API_V1}?limit=1"
import requests
logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")
try:
r = requests.get(
probe_url,
@@ -252,59 +343,23 @@ class OnyxConfluence:
raise e
return
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Probing Confluence with OAuth Access Token.")
# Initialize connection with probe timeout settings
self._confluence = self._initialize_connection_helper(
credentials, **merged_kwargs
)
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
credentials
)
url = (
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
)
confluence_client_with_minimal_retries = Confluence(
url=url, oauth2=oauth2_dict, **merged_kwargs
)
else:
logger.info("Probing Confluence with Personal Access Token.")
url = self._url
if self._is_cloud:
logger.info("running with cloud client")
confluence_client_with_minimal_retries = Confluence(
url=url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**merged_kwargs,
)
else:
confluence_client_with_minimal_retries = Confluence(
url=url,
token=credentials["confluence_access_token"],
**merged_kwargs,
)
# Retrieve first space to validate connection
spaces_iter = self.retrieve_confluence_spaces(limit=1)
first_space = next(spaces_iter, None)
# This call sometimes hangs indefinitely, so we run it in a timeout
spaces = run_with_timeout(
timeout=10,
func=confluence_client_with_minimal_retries.get_all_spaces,
limit=1,
if not first_space:
raise RuntimeError(
f"No spaces found at {self._url}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {url}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
logger.info("Confluence probe succeeded.")
logger.info("Confluence probe succeeded.")
def _initialize_connection(
self,

View File

@@ -191,7 +191,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
"""If dynamic, the credentials may change during usage ... meaning the client
needs to use the locking features of the credentials provider to operate
correctly.

View File

@@ -644,6 +644,7 @@ class JiraConnector(
jql=self.jql_query,
start=0,
max_results=1,
all_issue_ids=[],
)
),
None,

View File

@@ -1,7 +1,9 @@
import contextlib
import secrets
import string
import uuid
from typing import Any
from urllib.parse import urlparse
from fastapi import APIRouter
from fastapi import Depends
@@ -10,28 +12,23 @@ from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi_users import exceptions
from fastapi_users.authentication import Strategy
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.users import auth_backend
from onyx.auth.users import fastapi_users
from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SAML_CONF_DIR
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.saml import expire_saml_account
from onyx.db.saml import get_saml_account
from onyx.db.saml import upsert_saml_account
from onyx.utils.logger import setup_logger
from onyx.utils.secrets import encrypt_string
from onyx.utils.secrets import extract_hashed_cookie
logger = setup_logger()
@@ -165,35 +162,63 @@ class SAMLAuthorizeResponse(BaseModel):
authorization_url: str
def _sanitize_relay_state(candidate: str | None) -> str | None:
"""Ensure the relay state is an internal path to avoid open redirects."""
if not candidate:
return None
relay_state = candidate.strip()
if not relay_state or not relay_state.startswith("/"):
return None
if "\\" in relay_state:
return None
# Reject colon before query/fragment to match frontend validation
path_portion = relay_state.split("?", 1)[0].split("#", 1)[0]
if ":" in path_portion:
return None
parsed = urlparse(relay_state)
if parsed.scheme or parsed.netloc:
return None
return relay_state
@router.get("/authorize")
async def saml_login(request: Request) -> SAMLAuthorizeResponse:
req = await prepare_from_fastapi_request(request)
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)
callback_url = auth.login()
return_to = _sanitize_relay_state(request.query_params.get("next"))
callback_url = auth.login(return_to=return_to)
return SAMLAuthorizeResponse(authorization_url=callback_url)
@router.get("/callback")
async def saml_login_callback_get(
request: Request,
db_session: Session = Depends(get_session),
strategy: Strategy[User, uuid.UUID] = Depends(auth_backend.get_strategy),
user_manager: UserManager = Depends(get_user_manager),
) -> Response:
"""Handle SAML callback via HTTP-Redirect binding (GET request)"""
return await _process_saml_callback(request, db_session)
return await _process_saml_callback(request, strategy, user_manager)
@router.post("/callback")
async def saml_login_callback(
request: Request,
db_session: Session = Depends(get_session),
strategy: Strategy[User, uuid.UUID] = Depends(auth_backend.get_strategy),
user_manager: UserManager = Depends(get_user_manager),
) -> Response:
"""Handle SAML callback via HTTP-POST binding (POST request)"""
return await _process_saml_callback(request, db_session)
return await _process_saml_callback(request, strategy, user_manager)
async def _process_saml_callback(
request: Request,
db_session: Session,
strategy: Strategy[User, uuid.UUID],
user_manager: UserManager,
) -> Response:
req = await prepare_from_fastapi_request(request)
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)
@@ -251,40 +276,19 @@ async def _process_saml_callback(
user = await upsert_saml_user(email=user_email)
# Generate a random session cookie and Sha256 encrypt before saving
session_cookie = secrets.token_hex(16)
saved_cookie = encrypt_string(session_cookie)
upsert_saml_account(user_id=user.id, cookie=saved_cookie, db_session=db_session)
# Redirect to main Onyx search page
response = Response(status_code=status.HTTP_204_NO_CONTENT)
response.set_cookie(
key="session",
value=session_cookie,
httponly=True,
secure=True,
max_age=SESSION_EXPIRE_TIME_SECONDS,
)
response = await auth_backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
return response
@router.post("/logout")
async def saml_logout(
request: Request,
async_db_session: AsyncSession = Depends(get_async_session),
) -> None:
saved_cookie = extract_hashed_cookie(request)
if saved_cookie:
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
user_token: tuple[User, str] = Depends(
fastapi_users.authenticator.current_user_token(
active=True, verified=REQUIRE_EMAIL_VERIFICATION
)
if saml_account:
await expire_saml_account(
saml_account=saml_account, async_db_session=async_db_session
)
return
),
strategy: Strategy[User, uuid.UUID] = Depends(auth_backend.get_strategy),
) -> Response:
user, token = user_token
return await auth_backend.logout(strategy, user, token)

View File

@@ -1,14 +0,0 @@
import hashlib
from fastapi import Request
from onyx.configs.constants import SESSION_KEY
def encrypt_string(s: str) -> str:
return hashlib.sha256(s.encode()).hexdigest()
def extract_hashed_cookie(request: Request) -> str | None:
session_cookie = request.cookies.get(SESSION_KEY)
return encrypt_string(session_cookie) if session_cookie else None

View File

@@ -53,6 +53,8 @@ def jira_connector_with_jql() -> JiraConnector:
"jira_api_token": os.environ["JIRA_API_TOKEN"],
}
)
connector.validate_connector_settings()
return connector

View File

@@ -23,11 +23,11 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import S3BackedFileStore
class TestDBBase(DeclarativeBase):
class DBBaseTest(DeclarativeBase):
pass
class FileRecord(TestDBBase):
class FileRecord(DBBaseTest):
__tablename__: str = "file_record"
# Internal file ID, must be unique across all files
@@ -56,7 +56,7 @@ class FileRecord(TestDBBase):
def db_session() -> Generator[Session, None, None]:
"""Create an in-memory SQLite database for testing"""
engine = create_engine("sqlite:///:memory:")
TestDBBase.metadata.create_all(engine)
DBBaseTest.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)
session = SessionLocal()
yield session

View File

@@ -4,7 +4,6 @@ from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -67,6 +66,7 @@ def confluence_connector(
)
# Initialize the client directly
connector._confluence_client = mock_confluence_client
connector._low_timeout_confluence_client = mock_confluence_client
with patch("onyx.connectors.confluence.connector._SLIM_DOC_BATCH_SIZE", 2):
yield connector
@@ -355,27 +355,32 @@ def test_validate_connector_settings_errors(
"""Test validation with various error scenarios"""
error = HTTPError(response=MagicMock(status_code=status_code))
confluence_client = MagicMock()
confluence_connector._low_timeout_confluence_client = confluence_client
get_all_spaces_mock = cast(MagicMock, confluence_client.get_all_spaces)
get_all_spaces_mock.side_effect = error
with patch(
"onyx.connectors.confluence.onyx_confluence.OnyxConfluence.retrieve_confluence_spaces"
) as mock_retrieve:
mock_retrieve.side_effect = error
with pytest.raises(expected_exception) as excinfo:
confluence_connector.validate_connector_settings()
assert expected_message in str(excinfo.value)
with pytest.raises(expected_exception) as excinfo:
confluence_connector.validate_connector_settings()
assert expected_message in str(excinfo.value)
def test_validate_connector_settings_success(
confluence_connector: ConfluenceConnector,
) -> None:
"""Test successful validation"""
confluence_client = MagicMock()
confluence_connector._low_timeout_confluence_client = confluence_client
get_all_spaces_mock = cast(MagicMock, confluence_client.get_all_spaces)
get_all_spaces_mock.return_value = {"results": [{"key": "TEST"}]}
confluence_connector.validate_connector_settings()
get_all_spaces_mock.assert_called_once()
low_client = confluence_connector.low_timeout_confluence_client
with patch.object(
low_client, "retrieve_confluence_spaces", return_value=iter([{"key": "TEST"}])
) as mock_retrieve, patch.object(
low_client,
"get_space",
return_value={"key": "TEST"},
create=True,
) as mock_get_space:
confluence_connector.validate_connector_settings()
mock_retrieve.assert_called_once()
mock_get_space.assert_called_once_with(confluence_connector.space)
def test_checkpoint_progress(

View File

@@ -39,7 +39,6 @@ class _FakeResponse:
return None
@pytest.mark.unit
def test_zendesk_client_per_minute_rate_limiting(
monkeypatch: pytest.MonkeyPatch,
) -> None:

View File

@@ -1,3 +1,4 @@
import { validateInternalRedirect } from "@/lib/auth/redirectValidation";
import { getDomain } from "@/lib/redirectSS";
import { buildUrl } from "@/lib/utilsSS";
import { NextRequest, NextResponse } from "next/server";
@@ -28,16 +29,21 @@ async function handleSamlCallback(
},
};
let relayState: string | null = null;
// For POST requests, include form data
if (method === "POST") {
fetchOptions.body = await request.formData();
const formData = await request.formData();
const relayStateValue = formData.get("RelayState");
relayState = typeof relayStateValue === "string" ? relayStateValue : null;
fetchOptions.body = formData;
}
// OneLogin python toolkit only supports HTTP-POST binding for SAMLResponse.
// If the IdP returned SAMLResponse via query parameters (GET), convert to POST.
if (method === "GET") {
const samlResponse = request.nextUrl.searchParams.get("SAMLResponse");
const relayState = request.nextUrl.searchParams.get("RelayState");
relayState = request.nextUrl.searchParams.get("RelayState");
if (samlResponse) {
const formData = new FormData();
formData.set("SAMLResponse", samlResponse);
@@ -61,8 +67,11 @@ async function handleSamlCallback(
);
}
const validatedRelayState = validateInternalRedirect(relayState);
const redirectDestination = validatedRelayState ?? "/";
const redirectResponse = NextResponse.redirect(
new URL("/", getDomain(request)),
new URL(redirectDestination, getDomain(request)),
SEE_OTHER_REDIRECT_STATUS
);
redirectResponse.headers.set("set-cookie", setCookieHeader);