Compare commits

..

3 Commits

Author SHA1 Message Date
pablodanswer
6ff78e077d nit 2024-12-06 12:57:43 -08:00
pablodanswer
c01512f846 fix slackbot 2024-12-06 12:56:46 -08:00
rkuo-danswer
7a3c06c2d2 first cut at slack oauth flow (#3323)
* first cut at slack oauth flow

* fix usage of hooks

* fix button spacing

* add additional error logging

* no dev redirect

* cleanup

* comment work in progress

* move some stuff to ee, add some playwright tests for the oauth callback edge cases

* fix ee, fix test name

* fix tests

* code review fixes
2024-12-06 19:55:21 +00:00
21 changed files with 860 additions and 85 deletions

View File

@@ -0,0 +1,40 @@
"""non-nullbale slack bot id in channel config
Revision ID: f7a894b06d02
Revises: 9f696734098f
Create Date: 2024-12-06 12:55:42.845723
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f7a894b06d02"
down_revision = "9f696734098f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Delete all rows with null slack_bot_id
op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")
# Make slack_bot_id non-nullable
op.alter_column(
"slack_channel_config",
"slack_bot_id",
existing_type=sa.Integer(),
nullable=False,
)
def downgrade() -> None:
# Make slack_bot_id nullable again
op.alter_column(
"slack_channel_config",
"slack_bot_id",
existing_type=sa.Integer(),
nullable=True,
)

View File

@@ -219,7 +219,7 @@ def connector_permission_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,

View File

@@ -81,6 +81,12 @@ OAUTH_CLIENT_SECRET = (
or ""
)
# for future OAuth connector support
# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
# for basic auth

View File

@@ -248,7 +248,6 @@ def create_credential(
)
db_session.commit()
return credential

View File

@@ -1490,7 +1490,9 @@ class SlackChannelConfig(Base):
__tablename__ = "slack_channel_config"
id: Mapped[int] = mapped_column(primary_key=True)
slack_bot_id: Mapped[int] = mapped_column(ForeignKey("slack_bot.id"), nullable=True)
slack_bot_id: Mapped[int] = mapped_column(
ForeignKey("slack_bot.id"), nullable=False
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)

View File

@@ -105,7 +105,6 @@ from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()

View File

@@ -11,14 +11,6 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/danswer/configs/saml
#####
# Auto Permission Sync
#####
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -36,3 +28,6 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")

View File

@@ -10,9 +10,6 @@ from danswer.access.utils import prefix_group_w_source
from danswer.configs.constants import DocumentSource
from danswer.db.models import User__ExternalUserGroupId
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.utils.logger import setup_logger
logger = setup_logger()
class ExternalUserGroup(BaseModel):
@@ -76,13 +73,7 @@ def replace_user__ext_group_for_cc_pair(
new_external_permissions = []
for external_group in group_defs:
for user_email in external_group.user_emails:
user_id = email_id_map.get(user_email)
if user_id is None:
logger.warning(
f"User in group {external_group.id}"
f" with email {user_email} not found"
)
continue
user_id = email_id_map[user_email]
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,

View File

@@ -195,7 +195,6 @@ def _fetch_all_page_restrictions_for_space(
confluence_client: OnyxConfluence,
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@@ -223,50 +222,29 @@ def _fetch_all_page_restrictions_for_space(
continue
space_key = slim_doc.perm_sync_data.get("space_key")
if not (space_permissions := space_permissions_by_space_key.get(space_key)):
logger.debug(
f"Individually fetching space permissions for space {space_key}"
)
try:
# If the space permissions are not in the cache, then fetch them
if is_cloud:
retrieved_space_permissions = _get_cloud_space_permissions(
confluence_client=confluence_client, space_key=space_key
)
else:
retrieved_space_permissions = _get_server_space_permissions(
confluence_client=confluence_client, space_key=space_key
)
space_permissions_by_space_key[space_key] = retrieved_space_permissions
space_permissions = retrieved_space_permissions
except Exception as e:
logger.warning(
f"Error fetching space permissions for space {space_key}: {e}"
if space_permissions := space_permissions_by_space_key.get(space_key):
# If there are no restrictions, then use the space's restrictions
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
if not space_permissions:
logger.warning(
f"No permissions found for document {slim_doc.id} in space {space_key}"
)
if (
not space_permissions.is_public
and not space_permissions.external_user_emails
and not space_permissions.external_user_group_ids
):
logger.warning(
f"Permissions are empty for document: {slim_doc.id}\n"
"This means space permissions are may be wrong for"
f" Space key: {space_key}"
)
continue
# If there are no restrictions, then use the space's restrictions
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
logger.warning(
f"No permissions found for document {slim_doc.id} in space {space_key}"
)
if (
not space_permissions.is_public
and not space_permissions.external_user_emails
and not space_permissions.external_user_group_ids
):
logger.warning(
f"Permissions are empty for document: {slim_doc.id}\n"
"This means space permissions are may be wrong for"
f" Space key: {space_key}"
)
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
@@ -305,5 +283,4 @@ def confluence_doc_sync(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
)

View File

@@ -3,8 +3,6 @@ from collections.abc import Callable
from danswer.access.models import DocExternalAccess
from danswer.configs.constants import DocumentSource
from danswer.db.models import ConnectorCredentialPair
from ee.danswer.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.danswer.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
@@ -58,7 +56,7 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
# If nothing is specified here, we run the doc_sync every time the celery beat runs
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: 5 * 60,
DocumentSource.SLACK: 5 * 60,
}
@@ -66,7 +64,7 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: 30 * 60,
}

View File

@@ -26,6 +26,7 @@ from ee.danswer.server.enterprise_settings.api import (
)
from ee.danswer.server.manage.standard_answer import router as standard_answer_router
from ee.danswer.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.danswer.server.oauth import router as oauth_router
from ee.danswer.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -119,6 +120,8 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(
application, enterprise_settings_admin_router

View File

@@ -0,0 +1,423 @@
import base64
import uuid
from typing import cast
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
logger = setup_logger()
router = APIRouter(prefix="/oauth")
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.danswer.dev/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
# Work in progress
# class ConfluenceCloudOAuth:
# """work in progress"""
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
# class OAuthSession(BaseModel):
# """Stored in redis to be looked up on callback"""
# email: str
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
# CONFLUENCE_OAUTH_SCOPE = (
# "read:confluence-props%20"
# "read:confluence-content.all%20"
# "read:confluence-content.summary%20"
# "read:confluence-content.permission%20"
# "read:confluence-user%20"
# "read:confluence-groups%20"
# "readonly:content.attachment:confluence"
# )
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# # eventually for Confluence Data Center
# # oauth_url = (
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# # f"&redirect_uri={redirectme_uri}"
# # )
# @classmethod
# def generate_oauth_url(cls, state: str) -> str:
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
# @classmethod
# def generate_dev_oauth_url(cls, state: str) -> str:
# """dev mode workaround for localhost testing
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
# """
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
# @classmethod
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# url = (
# "https://auth.atlassian.com/authorize"
# f"?audience=api.atlassian.com"
# f"&client_id={cls.CLIENT_ID}"
# f"&redirect_uri={redirect_uri}"
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
# f"&state={state}"
# "&response_type=code"
# "&prompt=consent"
# )
# return url
# @classmethod
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
# """Temporary state to store in redis. to be looked up on auth response.
# Returns a json string.
# """
# session = ConfluenceCloudOAuth.OAuthSession(
# email=email, redirect_on_success=redirect_on_success
# )
# return session.model_dump_json()
# @classmethod
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
# return session
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
# email=user.email, redirect_on_success=redirect_on_success
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
# elif connector == DocumentSource.GOOGLE_DRIVE:
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
r = get_redis_client(tenant_id=tenant_id)
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": SlackOAuth.REDIRECT_URI,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"team_id": team_id,
"authed_user_id": authed_user_id,
"redirect_on_success": session.redirect_on_success,
}
)
# Work in progress
# @router.post("/connector/confluence/callback")
# def handle_confluence_oauth_callback(
# code: str,
# state: str,
# user: User = Depends(current_user),
# db_session: Session = Depends(get_session),
# tenant_id: str | None = Depends(get_current_tenant_id),
# ) -> JSONResponse:
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
# raise HTTPException(
# status_code=500,
# detail="Confluence client ID or client secret is not configured."
# )
# r = get_redis_client(tenant_id=tenant_id)
# # recover the state
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
# # Convert bytes back to a UUID
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
# oauth_uuid_str = str(oauth_uuid)
# r_key = f"da_oauth:{oauth_uuid_str}"
# result = r.get(r_key)
# if not result:
# raise HTTPException(
# status_code=400,
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
# )
# try:
# session = ConfluenceCloudOAuth.parse_session(result)
# # Exchange the authorization code for an access token
# response = requests.post(
# ConfluenceCloudOAuth.TOKEN_URL,
# headers={"Content-Type": "application/x-www-form-urlencoded"},
# data={
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
# "code": code,
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
# },
# )
# response_data = response.json()
# if not response_data.get("ok"):
# raise HTTPException(
# status_code=400,
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
# )
# # Extract token and team information
# access_token: str = response_data.get("access_token")
# team_id: str = response_data.get("team", {}).get("id")
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
# credential_info = CredentialBase(
# credential_json={"slack_bot_token": access_token},
# admin_public=True,
# source=DocumentSource.CONFLUENCE,
# name="Confluence OAuth",
# )
# logger.info(f"Slack access token: {access_token}")
# credential = create_credential(credential_info, user, db_session)
# logger.info(f"new_credential_id={credential.id}")
# except Exception as e:
# return JSONResponse(
# status_code=500,
# content={
# "success": False,
# "message": f"An error occurred during Slack OAuth: {str(e)}",
# },
# )
# finally:
# r.delete(r_key)
# # return the result
# return JSONResponse(
# content={
# "success": True,
# "message": "Slack OAuth completed successfully.",
# "team_id": team_id,
# "authed_user_id": authed_user_id,
# "redirect_on_success": session.redirect_on_success,
# }
# )

View File

@@ -9,9 +9,9 @@ import { AdminPageTitle } from "@/components/admin/Title";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useFormContext } from "@/components/context/FormContext";
import { getSourceDisplayName } from "@/lib/sources";
import { getSourceDisplayName, getSourceMetadata } from "@/lib/sources";
import { SourceIcon } from "@/components/SourceIcon";
import { useState } from "react";
import { useEffect, useState } from "react";
import { deleteCredential, linkCredential } from "@/lib/credential";
import { submitFiles } from "./pages/utils/files";
import { submitGoogleSite } from "./pages/utils/google_site";
@@ -43,6 +43,8 @@ import { Formik } from "formik";
import NavigationRow from "./NavigationRow";
import { useRouter } from "next/navigation";
import CardSection from "@/components/admin/CardSection";
import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils";
import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
export interface AdvancedConfig {
refreshFreq: number;
pruneFreq: number;
@@ -110,6 +112,23 @@ export default function AddConnector({
}: {
connector: ConfigurableSources;
}) {
const [currentPageUrl, setCurrentPageUrl] = useState<string | null>(null);
const [oauthUrl, setOauthUrl] = useState<string | null>(null);
const [isAuthorizing, setIsAuthorizing] = useState(false);
const [isAuthorizeVisible, setIsAuthorizeVisible] = useState(false);
useEffect(() => {
if (typeof window !== "undefined") {
setCurrentPageUrl(window.location.href);
}
if (EE_ENABLED && NEXT_PUBLIC_CLOUD_ENABLED) {
const sourceMetadata = getSourceMetadata(connector);
if (sourceMetadata?.oauthSupported == true) {
setIsAuthorizeVisible(true);
}
}
}, []);
const router = useRouter();
// State for managing credentials and files
@@ -135,8 +154,13 @@ export default function AddConnector({
const configuration: ConnectionConfiguration = connectorConfigs[connector];
// Form context and popup management
const { setFormStep, setAlowCreate, formStep, nextFormStep, prevFormStep } =
useFormContext();
const {
setFormStep,
setAllowCreate: setAllowCreate,
formStep,
nextFormStep,
prevFormStep,
} = useFormContext();
const { popup, setPopup } = usePopup();
// Hooks for Google Drive and Gmail credentials
@@ -192,7 +216,7 @@ export default function AddConnector({
const onSwap = async (selectedCredential: Credential<any>) => {
setCurrentCredential(selectedCredential);
setAlowCreate(true);
setAllowCreate(true);
setPopup({
message: "Swapped credential successfully!",
type: "success",
@@ -204,6 +228,37 @@ export default function AddConnector({
router.push("/admin/indexing/status?message=connector-created");
};
const handleAuthorize = async () => {
// authorize button handler
// gets an auth url from the server and directs the user to it in a popup
if (!currentPageUrl) return;
setIsAuthorizing(true);
try {
const response = await prepareOAuthAuthorizationRequest(
connector,
currentPageUrl
);
if (response.url) {
setOauthUrl(response.url);
window.open(response.url, "_blank", "noopener,noreferrer");
} else {
setPopup({ message: "Failed to fetch OAuth URL", type: "error" });
}
} catch (error: unknown) {
// Narrow the type of error
if (error instanceof Error) {
setPopup({ message: `Error: ${error.message}`, type: "error" });
} else {
// Handle non-standard errors
setPopup({ message: "An unknown error occurred", type: "error" });
}
} finally {
setIsAuthorizing(false);
}
};
return (
<Formik
initialValues={{
@@ -385,16 +440,31 @@ export default function AddConnector({
onSwitch={onSwap}
/>
{!createConnectorToggle && (
<button
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded"
onClick={() =>
setCreateConnectorToggle(
(createConnectorToggle) => !createConnectorToggle
)
}
>
Create New
</button>
<div className="mt-6 flex space-x-4">
{/* Button to pop up a form to manually enter credentials */}
<button
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
onClick={() =>
setCreateConnectorToggle(
(createConnectorToggle) => !createConnectorToggle
)
}
>
Create New
</button>
{/* Button to sign in via OAuth */}
<button
onClick={handleAuthorize}
className="mt-6 text-sm bg-blue-500 px-2 py-1.5 flex text-text-200 flex-none rounded"
disabled={isAuthorizing}
hidden={!isAuthorizeVisible}
>
{isAuthorizing
? "Authorizing..."
: `Authorize with ${getSourceDisplayName(connector)}`}
</button>
</div>
)}
{/* NOTE: connector will never be google_drive, since the ternary above will

View File

@@ -0,0 +1,111 @@
"use client";
import { useEffect, useState } from "react";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { AdminPageTitle } from "@/components/admin/Title";
import { Button } from "@/components/ui/button";
import Title from "@/components/ui/title";
import { KeyIcon } from "@/components/icons/icons";
import { getSourceMetadata, isValidSource } from "@/lib/sources";
import { ValidSources } from "@/lib/types";
import CardSection from "@/components/admin/CardSection";
import { handleOAuthAuthorizationResponse } from "@/lib/oauth_utils";
export default function OAuthCallbackPage() {
const router = useRouter();
const searchParams = useSearchParams();
const [statusMessage, setStatusMessage] = useState("Processing...");
const [statusDetails, setStatusDetails] = useState(
"Please wait while we complete the setup."
);
const [redirectUrl, setRedirectUrl] = useState<string | null>(null);
const [isError, setIsError] = useState(false);
const [pageTitle, setPageTitle] = useState(
"Authorize with Third-Party service"
);
// Extract query parameters
const code = searchParams.get("code");
const state = searchParams.get("state");
const pathname = usePathname();
const connector = pathname?.split("/")[3];
useEffect(() => {
const handleOAuthCallback = async () => {
if (!code || !state) {
setStatusMessage("Improperly formed OAuth authorization request.");
setStatusDetails(
!code ? "Missing authorization code." : "Missing state parameter."
);
setIsError(true);
return;
}
if (!connector || !isValidSource(connector)) {
setStatusMessage(
`The specified connector source type ${connector} does not exist.`
);
setStatusDetails(`${connector} is not a valid source type.`);
setIsError(true);
return;
}
const sourceMetadata = getSourceMetadata(connector as ValidSources);
setPageTitle(`Authorize with ${sourceMetadata.displayName}`);
setStatusMessage("Processing...");
setStatusDetails("Please wait while we complete authorization.");
setIsError(false); // Ensure no error state during loading
try {
const response = await handleOAuthAuthorizationResponse(code, state);
if (!response) {
throw new Error("Empty response from OAuth server.");
}
setStatusMessage("Success!");
setStatusDetails(
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
);
setRedirectUrl(response.redirect_on_success); // Extract the redirect URL
setIsError(false);
} catch (error) {
console.error("OAuth error:", error);
setStatusMessage("Oops, something went wrong!");
setStatusDetails(
"An error occurred during the OAuth process. Please try again."
);
setIsError(true);
}
};
handleOAuthCallback();
}, [code, state, connector]);
return (
<div className="container mx-auto py-8">
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
<div className="flex flex-col items-center justify-center min-h-screen">
<CardSection className="max-w-md">
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
<p className="text-text-500">{statusDetails}</p>
{redirectUrl && !isError && (
<div className="mt-4">
<p className="text-sm">
Click{" "}
<a href={redirectUrl} className="text-blue-500 underline">
here
</a>{" "}
to continue.
</p>
</div>
)}
</CardSection>
</div>
</div>
);
}

View File

@@ -20,7 +20,7 @@ interface FormContextType {
allowAdvanced: boolean;
setAllowAdvanced: React.Dispatch<React.SetStateAction<boolean>>;
allowCreate: boolean;
setAlowCreate: React.Dispatch<React.SetStateAction<boolean>>;
setAllowCreate: React.Dispatch<React.SetStateAction<boolean>>;
}
const FormContext = createContext<FormContextType | undefined>(undefined);
@@ -39,7 +39,7 @@ export const FormProvider: React.FC<{
const [formValues, setFormValues] = useState<Record<string, any>>({});
const [allowAdvanced, setAllowAdvanced] = useState(false);
const [allowCreate, setAlowCreate] = useState(false);
const [allowCreate, setAllowCreate] = useState(false);
const nextFormStep = (values = "") => {
setFormStep((prevStep) => prevStep + 1);
@@ -88,7 +88,7 @@ export const FormProvider: React.FC<{
allowAdvanced,
setAllowAdvanced,
allowCreate,
setAlowCreate,
setAllowCreate,
};
return (

View File

@@ -1,6 +1,7 @@
"use client";
import {
ConnectorIndexingStatus,
OAuthSlackCallbackResponse,
DocumentBoostStatus,
Tag,
UserGroup,

View File

@@ -0,0 +1,80 @@
import {
OAuthPrepareAuthorizationResponse,
OAuthSlackCallbackResponse,
} from "./types";
// server side handler to help initiate the oauth authorization request
export async function prepareOAuthAuthorizationRequest(
connector: string,
finalRedirect: string | null // a redirect (not the oauth redirect) for the user to return to after oauth is complete)
): Promise<OAuthPrepareAuthorizationResponse> {
let url = `/api/oauth/prepare-authorization-request?connector=${encodeURIComponent(
connector
)}`;
// Conditionally append the `redirect_on_success` parameter
if (finalRedirect) {
url += `&redirect_on_success=${encodeURIComponent(finalRedirect)}`;
}
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
connector: connector,
redirect_on_success: finalRedirect,
}),
});
if (!response.ok) {
throw new Error(
`Failed to prepare OAuth authorization request: ${response.status}`
);
}
// Parse the JSON response
const data = (await response.json()) as OAuthPrepareAuthorizationResponse;
return data;
}
// server side handler to process the oauth redirect callback
// https://api.slack.com/authentication/oauth-v2#exchanging
export async function handleOAuthAuthorizationResponse(
code: string,
state: string
): Promise<OAuthSlackCallbackResponse> {
const url = `/api/oauth/connector/slack/callback?code=${encodeURIComponent(
code
)}&state=${encodeURIComponent(state)}`;
const response = await fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ code, state }),
});
if (!response.ok) {
let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`;
try {
const responseBody = await response.text(); // Read the body as text
errorDetails += `\nResponse Body: ${responseBody}`;
} catch (err) {
if (err instanceof Error) {
errorDetails += `\nUnable to read response body: ${err.message}`;
} else {
errorDetails += `\nUnable to read response body: Unknown error type`;
}
}
throw new Error(errorDetails);
}
// Parse the JSON response
const data = (await response.json()) as OAuthSlackCallbackResponse;
return data;
}

View File

@@ -124,6 +124,7 @@ export interface SourceMetadata {
shortDescription?: string;
internalName: ValidSources;
adminUrl: string;
oauthSupported?: boolean;
}
export interface SearchDefaultOverrides {

View File

@@ -76,6 +76,7 @@ export const SOURCE_METADATA_MAP: SourceMap = {
displayName: "Slack",
category: SourceCategory.Messaging,
docs: "https://docs.danswer.dev/connectors/slack",
oauthSupported: true,
},
gmail: {
icon: GmailIcon,
@@ -341,6 +342,7 @@ export function listSourceMetadata(): SourceMetadata[] {
export function getSourceDocLink(sourceType: ValidSources): string | null {
return SOURCE_METADATA_MAP[sourceType].docs || null;
}
export const isValidSource = (sourceType: string) => {
return Object.keys(SOURCE_METADATA_MAP).includes(sourceType);
};

View File

@@ -135,6 +135,18 @@ export interface ConnectorIndexingStatus<
in_progress: boolean;
}
export interface OAuthPrepareAuthorizationResponse {
url: string;
}
export interface OAuthSlackCallbackResponse {
success: boolean;
message: string;
team_id: string;
authed_user_id: string;
redirect_on_success: string;
}
export interface CCPairBasicInfo {
has_successful_run: boolean;
source: ValidSources;

View File

@@ -0,0 +1,65 @@
import { test, expect } from "@chromatic-com/playwright";
test(
"Admin - OAuth Redirect - Missing Code",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?state=xyz"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"Missing authorization code."
);
}
);
test(
"Admin - OAuth Redirect - Missing State",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"Missing state parameter."
);
}
);
test(
"Admin - OAuth Redirect - Invalid Connector",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/invalid-connector/oauth/callback?code=123&state=xyz"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"invalid-connector is not a valid source type."
);
}
);
test(
"Admin - OAuth Redirect - No Session",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
await page.goto(
"http://localhost:3000/admin/connectors/slack/oauth/callback?code=123&state=xyz"
);
await expect(page.locator("p.text-text-500")).toHaveText(
"An error occurred during the OAuth process. Please try again."
);
}
);