Compare commits

..

14 Commits

Author SHA1 Message Date
pablonyx
aacdf775da add basic user invite flow 2025-03-11 11:14:52 -07:00
pablonyx
59a388ce0a fix tests 2025-03-11 11:12:35 -07:00
rkuo-danswer
9cd3cbb978 fix versions (#4250)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-10 23:50:07 -07:00
pablonyx
ab1b6b487e descrease model server logspam (#4166) 2025-03-10 18:29:27 +00:00
Chris Weaver
6ead9510a4 Small notion tweaks (#4244)
* Small notion tweaks

* Add comment
2025-03-10 15:51:12 +00:00
Chris Weaver
965f9e98bf Eliminate extremely long log line for large checkpointds (#4236)
* Eliminate extremely long log line for large checkpointds

* address greptile
2025-03-10 15:50:50 +00:00
rkuo-danswer
426883bbf5 Feature/agentic buffered (#4231)
* rename agent test script to prevent pytest autodiscovery

* first cut

* fix log message

* fix up typing

* add a sample test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-10 15:48:42 +00:00
rkuo-danswer
6ca400ced9 Bugfix/delete document tags slow (#4232)
* Add Missing Date and Message-ID Headers to Ensure Email Delivery

* fix issue Performance issue during connector deletion #4191

* fix ruff

* bump to rebuild PR

---------

Co-authored-by: ThomaciousD <2194608+ThomaciousD@users.noreply.github.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-10 03:07:30 +00:00
Weves
104c4b9f4d small modal improvement 2025-03-09 20:54:53 -07:00
pablonyx
8b5e8bd5b9 k (#4240) 2025-03-10 03:06:13 +00:00
Weves
7f7621d7c0 SMall gitbook tweaks 2025-03-09 14:46:44 -07:00
pablonyx
06dcc28d05 Improved login experience (#4178)
* functional initial auth modal

* k

* k

* k

* looking good

* k

* k

* k

* k

* update

* k

* k

* misc bunch

* improvements

* k

* address comments

* k

* nit

* update

* k
2025-03-09 01:06:20 +00:00
pablonyx
18df63dfd9 Fix local background jobs (#4241) 2025-03-08 14:47:56 -08:00
Chris Weaver
0d3c72acbf Add basic memory logging (#4234)
* Add basic memory logging

* Small tweaks

* Switch to monotonic
2025-03-08 03:49:47 +00:00
87 changed files with 2634 additions and 582 deletions

View File

@@ -48,6 +48,8 @@ env:
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
jobs:
connectors-check:

View File

@@ -114,3 +114,4 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -0,0 +1,51 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -1,10 +1,14 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@@ -14,13 +18,19 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -89,6 +99,12 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@@ -97,6 +113,15 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -113,11 +138,104 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@@ -1,3 +1,5 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -9,6 +11,7 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -88,6 +91,64 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -107,6 +168,12 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -0,0 +1,45 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -0,0 +1,98 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -1,269 +1,24 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)

View File

@@ -0,0 +1,96 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,3 +67,30 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -4,6 +4,7 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -14,6 +15,7 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -353,3 +355,47 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A dictionary containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None

View File

@@ -0,0 +1,67 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -0,0 +1,39 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -0,0 +1,90 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,27 +1,56 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = logging.getLogger(__name__)
logger = setup_logger()
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
@@ -41,7 +70,9 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
db_session.add(
UserTenantMapping(email=email, tenant_id=tenant_id, active=False)
)
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
@@ -76,3 +107,187 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None

View File

@@ -62,6 +62,60 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
@@ -92,31 +146,17 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
logger.error(error_string)
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -155,7 +195,6 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -239,22 +278,51 @@ class CloudEmbedding:
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
@staticmethod
def create(
@@ -569,6 +637,13 @@ async def process_embed_request(
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,

View File

@@ -153,7 +153,8 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
msg["From"] = mail_from
if mail_from:
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")

View File

@@ -1,5 +1,6 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -18,3 +19,17 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -100,6 +100,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
@@ -894,7 +895,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accesssible_user(
async def current_chat_accessible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1095,6 +1096,12 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1126,9 +1133,14 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
redirect_response = RedirectResponse(next_url, status_code=302)
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
@@ -1140,6 +1152,7 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -5,40 +5,53 @@ from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE, maxBytes=MEMORY_LOG_MAX_BYTES, backupCount=MEMORY_LOG_BACKUP_COUNT
)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()

View File

@@ -1,10 +1,13 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from typing import Union
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -44,9 +47,44 @@ class LlmDoc(BaseModel):
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
@@ -336,6 +374,8 @@ class AgentAnswerPiece(SubQuestionIdentifier):
class SubQuestionPiece(SubQuestionIdentifier):
"""Refined sub questions generated from the initial user question."""
sub_question: str
@@ -347,13 +387,13 @@ class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = (
AgentSearchPacket = Union[
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
)
]
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse

View File

@@ -76,6 +76,7 @@ KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_PENDING_USERS_KEY = "PENDING_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"

View File

@@ -228,10 +228,15 @@ class GitbookConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("GitBook")
try:
content = self.client.get(f"/spaces/{self.space_id}/content")
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
logger.info(f"Found {len(pages)} root pages.")
logger.info(
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
)
while pages:
page = pages.pop(0)

View File

@@ -1,3 +1,4 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -204,6 +205,15 @@ class ConnectorCheckpoint(BaseModel):
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
class DocumentFailure(BaseModel):
document_id: str

View File

@@ -1,4 +1,3 @@
import time
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
@@ -32,6 +31,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
"""
filtered_pages: list[NotionPage] = []
for page in pages:
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
# Parse ISO 8601 timestamp and convert to UTC epoch time
timestamp = page[filter_field].replace(".000Z", "+00:00")
compare_time = datetime.fromisoformat(timestamp).timestamp()
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": self.batch_size,
"page_size": _NOTION_PAGE_SIZE,
}
while True:
db_res = self._search_notion(query_dict)
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
return
query_dict = {
"page_size": self.batch_size,
"page_size": _NOTION_PAGE_SIZE,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}

View File

@@ -674,7 +674,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists.
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
@@ -706,8 +706,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
f"Slack API returned a failure: {error_msg}"
)
# 3) If channels are specified, verify each is accessible
if self.channels:
# 3) If channels are specified and regex is not enabled, verify each is accessible
if self.channels and not self.channel_regex_enabled:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,

View File

@@ -2295,15 +2295,14 @@ class PublicBase(DeclarativeBase):
__abstract__ = True
# Strictly keeps track of the tenant that a given user will authenticate to.
class UserTenantMapping(Base):
__tablename__ = "user_tenant_mapping"
__table_args__ = (
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
{"schema": "public"},
)
__table_args__ = ({"schema": "public"},)
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
@validates("email")
def validate_email(self, key: str, value: str) -> str:

View File

@@ -1,6 +1,5 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -149,11 +148,10 @@ def delete_document_tags_for_documents__no_commit(
stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
db_session.execute(stmt)
orphan_tags_query = (
select(Tag.id)
.outerjoin(Document__Tag, Tag.id == Document__Tag.tag_id)
.group_by(Tag.id)
.having(func.count(Document__Tag.document_id) == 0)
orphan_tags_query = select(Tag.id).where(
~db_session.query(Document__Tag.tag_id)
.filter(Document__Tag.tag_id == Tag.id)
.exists()
)
orphan_tags = db_session.execute(orphan_tags_query).scalars().all()

View File

@@ -234,6 +234,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
SqlEngine.reset_engine()
if AUTH_RATE_LIMITING_ENABLED:
await close_auth_limiter()

View File

@@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -112,7 +112,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accesssible_user
or depends_fn == current_chat_accessible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):

View File

@@ -17,7 +17,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
@@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel):
@router.get("/connector-status")
def get_basic_connector_indexing_status(
user: User = Depends(current_chat_accesssible_user),
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
cc_pairs = get_connector_credential_pairs_for_user(

View File

@@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -390,7 +390,7 @@ def get_image_generation_tool(
@basic_router.get("")
def list_personas(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),

View File

@@ -7,7 +7,7 @@ from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
@@ -191,7 +191,7 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [

View File

@@ -53,6 +53,16 @@ class UserPreferences(BaseModel):
temperature_override_enabled: bool | None = None
class TenantSnapshot(BaseModel):
tenant_id: str
number_of_users: int
class TenantInfo(BaseModel):
invitation: TenantSnapshot | None = None
new_tenant: TenantSnapshot | None = None
class UserInfo(BaseModel):
id: str
email: str
@@ -65,9 +75,10 @@ class UserInfo(BaseModel):
current_token_created_at: datetime | None = None
current_token_expiry_length: int | None = None
is_cloud_superuser: bool = False
organization_name: str | None = None
team_name: str | None = None
is_anonymous_user: bool | None = None
password_configured: bool | None = None
tenant_info: TenantInfo | None = None
@classmethod
def from_model(
@@ -76,8 +87,9 @@ class UserInfo(BaseModel):
current_token_created_at: datetime | None = None,
expiry_length: int | None = None,
is_cloud_superuser: bool = False,
organization_name: str | None = None,
team_name: str | None = None,
is_anonymous_user: bool | None = None,
tenant_info: TenantInfo | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
@@ -99,7 +111,7 @@ class UserInfo(BaseModel):
temperature_override_enabled=user.temperature_override_enabled,
)
),
organization_name=organization_name,
team_name=team_name,
# set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases
# where they previously had this set + used OIDC, and now they switched to
# basic auth are now constantly getting redirected back to the login page
@@ -109,6 +121,7 @@ class UserInfo(BaseModel):
current_token_expiry_length=expiry_length,
is_cloud_superuser=is_cloud_superuser,
is_anonymous_user=is_anonymous_user,
tenant_info=tenant_info,
)

View File

@@ -12,13 +12,11 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from psycopg2.errors import UniqueViolation
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import SUPER_USERS
@@ -55,6 +53,8 @@ from onyx.key_value_store.factory import get_kv_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import TenantInfo
from onyx.server.manage.models import TenantSnapshot
from onyx.server.manage.models import UserByEmail
from onyx.server.manage.models import UserInfo
from onyx.server.manage.models import UserPreferences
@@ -296,13 +296,6 @@ def bulk_invite_users(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
)(new_invited_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise HTTPException(
status_code=400,
detail="User has already been invited to a Onyx organization",
)
raise
except Exception as e:
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
@@ -425,6 +418,10 @@ async def delete_user(
db_session.expunge(user_to_delete)
try:
tenant_id = get_current_tenant_id()
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
delete_user_from_db(user_to_delete, db_session)
logger.info(f"Deleted user {user_to_delete.email}")
@@ -553,8 +550,8 @@ def verify_user_logged_in(
if anonymous_user_enabled(tenant_id=tenant_id):
store = get_kv_store()
return fetch_no_auth_user(store, anonymous_user_enabled=True)
raise BasicAuthenticationError(detail="User Not Authenticated")
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
@@ -563,16 +560,35 @@ def verify_user_logged_in(
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
)
organization_name = fetch_ee_implementation_or_noop(
team_name = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user.email)
new_tenant: TenantSnapshot | None = None
tenant_invitation: TenantSnapshot | None = None
if MULTI_TENANT:
if team_name != get_current_tenant_id():
user_count = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_count", None
)(team_name)
new_tenant = TenantSnapshot(tenant_id=team_name, number_of_users=user_count)
tenant_invitation = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_invitation", None
)(user.email)
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
is_cloud_superuser=user.email in SUPER_USERS,
organization_name=organization_name,
team_name=team_name,
tenant_info=TenantInfo(
new_tenant=new_tenant,
invitation=tenant_invitation,
),
)
return user_info

View File

@@ -49,9 +49,9 @@ class FullUserSnapshot(BaseModel):
)
class InvitedUserSnapshot(BaseModel):
email: str
class DisplayPriorityRequest(BaseModel):
display_priority_map: dict[int, int]
class InvitedUserSnapshot(BaseModel):
email: str

View File

@@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import extract_headers
@@ -190,7 +190,7 @@ def update_chat_session_model(
def get_chat_session(
session_id: UUID,
is_shared: bool = False,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> ChatSessionDetailResponse:
user_id = user.id if user is not None else None
@@ -246,7 +246,7 @@ def get_chat_session(
@router.post("/create-chat-session")
def create_new_chat_session(
chat_session_creation_request: ChatSessionCreationRequest,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> CreateChatSessionID:
user_id = user.id if user is not None else None
@@ -381,7 +381,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
_rate_limit_check: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
) -> StreamingResponse:
@@ -473,7 +473,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None

View File

@@ -11,7 +11,7 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session_context_manager
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -29,7 +29,7 @@ TOKEN_BUDGET_UNIT = 1_000
def check_token_rate_limits(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
) -> None:
# short circuit if no rate limits are set up
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time

View File

@@ -32,15 +32,15 @@ class InCodeToolInfo(TypedDict):
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
InCodeToolInfo(
cls=SearchTool,
description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
description="The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
in_code_tool_id=SearchTool.__name__,
display_name=SearchTool._DISPLAY_NAME,
),
InCodeToolInfo(
cls=ImageGenerationTool,
description=(
"The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
"The tool will be used when the user asks the assistant to generate an image."
"The Image Generation Action allows the assistant to use DALL-E 3 to generate images. "
"The action will be used when the user asks the assistant to generate an image."
),
in_code_tool_id=ImageGenerationTool.__name__,
display_name=ImageGenerationTool._DISPLAY_NAME,
@@ -51,7 +51,7 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
InCodeToolInfo(
cls=InternetSearchTool,
description=(
"The Internet Search Tool allows the assistant "
"The Internet Search Action allows the assistant "
"to perform internet searches for up-to-date information."
),
in_code_tool_id=InternetSearchTool.__name__,
@@ -98,7 +98,7 @@ def load_builtin_tools(db_session: Session) -> None:
for tool_id, tool in list(in_code_tool_id_to_tool.items()):
if tool_id not in built_in_ids:
db_session.delete(tool)
logger.notice(f"Removed tool no longer in built-in list: {tool.name}")
logger.notice(f"Removed action no longer in built-in list: {tool.name}")
db_session.commit()
logger.notice("All built-in tools are loaded/verified.")

43
backend/onyx/utils/url.py Normal file
View File

@@ -0,0 +1,43 @@
from urllib.parse import parse_qs
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlunparse
def add_url_params(url: str, params: dict) -> str:
"""
Add parameters to a URL, handling existing parameters properly.
Args:
url: The original URL
params: Dictionary of parameters to add
Returns:
URL with added parameters
"""
# Parse the URL
parsed_url = urlparse(url)
# Get existing query parameters
query_params = parse_qs(parsed_url.query)
# Update with new parameters
for key, value in params.items():
query_params[key] = [value]
# Build the new query string
new_query = urlencode(query_params, doseq=True)
# Reconstruct the URL with the new query string
new_url = urlunparse(
(
parsed_url.scheme,
parsed_url.netloc,
parsed_url.path,
parsed_url.params,
new_query,
parsed_url.fragment,
)
)
return new_url

View File

@@ -1,4 +1,4 @@
black==23.3.0
black==23.7.0
boto3-stubs[s3]==1.34.133
celery-types==0.19.0
cohere==5.6.1

View File

@@ -54,6 +54,7 @@ class OnyxRedisCommand(Enum):
purge_vespa_syncing = "purge_vespa_syncing"
get_user_token = "get_user_token"
delete_user_token = "delete_user_token"
add_invited_user = "add_invited_user"
def __str__(self) -> str:
return self.value
@@ -163,6 +164,21 @@ def onyx_redis(
return 0
else:
return 2
elif command == OnyxRedisCommand.add_invited_user:
if not user_email:
logger.error("You must specify --user-email with add_invited_user")
return 1
current_invited_users = get_invited_users()
if user_email not in current_invited_users:
current_invited_users.append(user_email)
if dry_run:
logger.info(f"(DRY-RUN) Would add {user_email} to invited users")
else:
write_invited_users(current_invited_users)
logger.info(f"Added {user_email} to invited users")
else:
logger.info(f"{user_email} is already in the invited users list")
return 0
else:
pass
@@ -441,23 +457,6 @@ if __name__ == "__main__":
if args.tenant_id:
CURRENT_TENANT_ID_CONTEXTVAR.set(args.tenant_id)
if args.command == "add_invited_user":
if not args.user_email:
print("Error: --user-email is required for add_invited_user command")
sys.exit(1)
current_invited_users = get_invited_users()
if args.user_email not in current_invited_users:
current_invited_users.append(args.user_email)
if args.dry_run:
print(f"(DRY-RUN) Would add {args.user_email} to invited users")
else:
write_invited_users(current_invited_users)
print(f"Added {args.user_email} to invited users")
else:
print(f"{args.user_email} is already in the invited users list")
sys.exit(0)
exitcode = onyx_redis(
command=args.command,
batch=args.batch,

View File

@@ -36,6 +36,7 @@ def confluence_connector() -> ConfluenceConnector:
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_basic(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:

View File

@@ -28,6 +28,7 @@ def confluence_connector() -> ConfluenceConnector:
# This should never fail because even if the docs in the cloud change,
# the full doc ids retrieved should always be a subset of the slim doc ids
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_permissions(
confluence_connector: ConfluenceConnector,
) -> None:

View File

@@ -20,29 +20,32 @@ def gitbook_connector() -> GitbookConnector:
return connector
NUM_PAGES = 3
def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
doc_batch_generator = gitbook_connector.load_from_state()
# Get first batch of documents
doc_batch = next(doc_batch_generator)
assert len(doc_batch) > 0
assert len(doc_batch) == NUM_PAGES
# Verify first document structure
doc = doc_batch[0]
main_doc = doc_batch[0]
# Basic document properties
assert doc.id.startswith("gitbook-")
assert doc.semantic_identifier == "Acme Corp Internal Handbook"
assert doc.source == DocumentSource.GITBOOK
assert main_doc.id.startswith("gitbook-")
assert main_doc.semantic_identifier == "Acme Corp Internal Handbook"
assert main_doc.source == DocumentSource.GITBOOK
# Metadata checks
assert "path" in doc.metadata
assert "type" in doc.metadata
assert "kind" in doc.metadata
assert "path" in main_doc.metadata
assert "type" in main_doc.metadata
assert "kind" in main_doc.metadata
# Section checks
assert len(doc.sections) == 1
section = doc.sections[0]
assert len(main_doc.sections) == 1
section = main_doc.sections[0]
# Content specific checks
content = section.text
@@ -74,8 +77,23 @@ def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
assert section.link # Should have a URL
nested1 = doc_batch[1]
assert nested1.id.startswith("gitbook-")
assert nested1.semantic_identifier == "Nested1"
assert len(nested1.sections) == 1
# extra newlines at the end, remove them to make test easier
assert nested1.sections[0].text.strip() == "nested1"
assert nested1.source == DocumentSource.GITBOOK
nested2 = doc_batch[2]
assert nested2.id.startswith("gitbook-")
assert nested2.semantic_identifier == "Nested2"
assert len(nested2.sections) == 1
assert nested2.sections[0].text.strip() == "nested2"
assert nested2.source == DocumentSource.GITBOOK
# Time-based polling test
current_time = time.time()
poll_docs = gitbook_connector.poll_source(0, current_time)
poll_batch = next(poll_docs)
assert len(poll_batch) > 0
assert len(poll_batch) == NUM_PAGES

View File

@@ -0,0 +1,128 @@
import os
import time
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.notion.connector import NotionConnector
@pytest.fixture
def notion_connector() -> NotionConnector:
"""Create a NotionConnector with credentials from environment variables"""
connector = NotionConnector()
connector.load_credentials(
{
"notion_integration_token": os.environ["NOTION_INTEGRATION_TOKEN"],
}
)
return connector
def test_notion_connector_basic(notion_connector: NotionConnector) -> None:
"""Test the NotionConnector with a real Notion page.
Uses a Notion workspace under the onyx-test.com domain.
"""
doc_batch_generator = notion_connector.poll_source(0, time.time())
# Get first batch of documents
doc_batch = next(doc_batch_generator)
assert (
len(doc_batch) == 5
), "Expected exactly 5 documents (root, two children, table entry, and table entry child)"
# Find root and child documents by semantic identifier
root_doc = None
child1_doc = None
child2_doc = None
table_entry_doc = None
table_entry_child_doc = None
for doc in doc_batch:
if doc.semantic_identifier == "Root":
root_doc = doc
elif doc.semantic_identifier == "Child1":
child1_doc = doc
elif doc.semantic_identifier == "Child2":
child2_doc = doc
elif doc.semantic_identifier == "table-entry01":
table_entry_doc = doc
elif doc.semantic_identifier == "Child-table-entry01":
table_entry_child_doc = doc
assert root_doc is not None, "Root document not found"
assert child1_doc is not None, "Child1 document not found"
assert child2_doc is not None, "Child2 document not found"
assert table_entry_doc is not None, "Table entry document not found"
assert table_entry_child_doc is not None, "Table entry child document not found"
# Verify root document structure
assert root_doc.id is not None
assert root_doc.source == DocumentSource.NOTION
# Section checks for root
assert len(root_doc.sections) == 1
root_section = root_doc.sections[0]
# Content specific checks for root
assert root_section.text == "\nroot"
assert root_section.link is not None
assert root_section.link.startswith("https://www.notion.so/")
# Verify child1 document structure
assert child1_doc.id is not None
assert child1_doc.source == DocumentSource.NOTION
# Section checks for child1
assert len(child1_doc.sections) == 1
child1_section = child1_doc.sections[0]
# Content specific checks for child1
assert child1_section.text == "\nchild1"
assert child1_section.link is not None
assert child1_section.link.startswith("https://www.notion.so/")
# Verify child2 document structure (includes database)
assert child2_doc.id is not None
assert child2_doc.source == DocumentSource.NOTION
# Section checks for child2
assert len(child2_doc.sections) == 2 # One for content, one for database
child2_section = child2_doc.sections[0]
child2_db_section = child2_doc.sections[1]
# Content specific checks for child2
assert child2_section.text == "\nchild2"
assert child2_section.link is not None
assert child2_section.link.startswith("https://www.notion.so/")
# Database section checks for child2
assert child2_db_section.text.strip() != "" # Should contain some database content
assert child2_db_section.link is not None
assert child2_db_section.link.startswith("https://www.notion.so/")
# Verify table entry document structure
assert table_entry_doc.id is not None
assert table_entry_doc.source == DocumentSource.NOTION
# Section checks for table entry
assert len(table_entry_doc.sections) == 1
table_entry_section = table_entry_doc.sections[0]
# Content specific checks for table entry
assert table_entry_section.text == "\ntable-entry01"
assert table_entry_section.link is not None
assert table_entry_section.link.startswith("https://www.notion.so/")
# Verify table entry child document structure
assert table_entry_child_doc.id is not None
assert table_entry_child_doc.source == DocumentSource.NOTION
# Section checks for table entry child
assert len(table_entry_child_doc.sections) == 1
table_entry_child_section = table_entry_child_doc.sections[0]
# Content specific checks for table entry child
assert table_entry_child_section.text == "\nchild-table-entry01"
assert table_entry_child_section.link is not None
assert table_entry_child_section.link.startswith("https://www.notion.so/")

View File

@@ -2,7 +2,7 @@ FROM python:3.11.7-slim-bookworm
WORKDIR /app
RUN pip install fastapi uvicorn
RUN pip install "pydantic-core>=2.28.0" fastapi uvicorn
COPY ./main.py /app/main.py

View File

@@ -1,3 +1,8 @@
from typing import Any
import pytest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
@@ -17,3 +22,58 @@ def test_send_message_simple_with_history(reset: None) -> None:
)
assert len(response.full_message) > 0
@pytest.mark.skip(
reason="enable for autorun when we have a testing environment with semantically useful data"
)
def test_send_message_simple_with_history_buffered() -> None:
import requests
API_KEY = "" # fill in for this to work
headers = {}
headers["Authorization"] = f"Bearer {API_KEY}"
req: dict[str, Any] = {}
req["persona_id"] = 0
req["description"] = "test_send_message_simple_with_history_buffered"
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session", headers=headers, json=req
)
chat_session_id = response.json()["chat_session_id"]
req = {}
req["chat_session_id"] = chat_session_id
req["message"] = "What does onyx do?"
req["use_agentic_search"] = True
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-api", headers=headers, json=req
)
r_json = response.json()
# all of these should exist and be greater than length 1
assert len(r_json.get("answer", "")) > 0
assert len(r_json.get("agent_sub_questions", "")) > 0
assert len(r_json.get("agent_answers")) > 0
assert len(r_json.get("agent_sub_queries")) > 0
assert "agent_refined_answer_improvement" in r_json
# top level answer should match the one we select out of agent_answers
answer_level = 0
agent_level_answer = ""
agent_refined_answer_improvement = r_json.get("agent_refined_answer_improvement")
if agent_refined_answer_improvement:
answer_level = len(r_json["agent_answers"]) - 1
answers = r_json["agent_answers"][str(answer_level)]
for answer in answers:
if answer["answer_type"] == "agent_level_answer":
agent_level_answer = answer["answer"]
break
assert r_json["answer"] == agent_level_answer
assert response.status_code == 200

View File

@@ -2,14 +2,7 @@
import { useState, useEffect, useCallback } from "react";
import { useRouter } from "next/navigation";
import {
Formik,
Form,
Field,
ErrorMessage,
FieldArray,
ArrayHelpers,
} from "formik";
import { Formik, Form, Field, ErrorMessage, FieldArray } from "formik";
import * as Yup from "yup";
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
import { TextFormField } from "@/components/admin/connectors/Field";
@@ -49,7 +42,7 @@ function prettifyDefinition(definition: any) {
return JSON.stringify(definition, null, 2);
}
function ToolForm({
function ActionForm({
existingTool,
values,
setFieldValue,
@@ -118,7 +111,7 @@ function ToolForm({
<TextFormField
name="definition"
label="Definition"
subtext="Specify an OpenAPI schema that defines the APIs you want to make available as part of this tool."
subtext="Specify an OpenAPI schema that defines the APIs you want to make available as part of this action."
placeholder="Enter your OpenAPI schema here"
isTextArea={true}
defaultHeight="h-96"
@@ -185,7 +178,7 @@ function ToolForm({
clipRule="evenodd"
/>
</svg>
Learn more about tool calling in our documentation
Learn more about actions in our documentation
</Link>
</div>
@@ -229,7 +222,7 @@ function ToolForm({
Custom Headers
</h3>
<p className="text-sm mb-6 text-text-600 italic">
Specify custom headers for each request to this tool&apos;s API.
Specify custom headers for each request to this action&apos;s API.
</p>
<FieldArray
name="customHeaders"
@@ -360,7 +353,7 @@ function ToolForm({
type="submit"
disabled={isSubmitting || !!definitionError}
>
{existingTool ? "Update Tool" : "Create Tool"}
{existingTool ? "Update Action" : "Create Action"}
</Button>
</div>
</Form>
@@ -386,7 +379,7 @@ const ToolSchema = Yup.object().shape({
passthrough_auth: Yup.boolean().default(false),
});
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
export function ActionEditor({ tool }: { tool?: ToolSnapshot }) {
const router = useRouter();
const { popup, setPopup } = usePopup();
const [definitionError, setDefinitionError] = useState<string | null>(null);
@@ -432,7 +425,7 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
try {
definition = parseJsonWithTrailingCommas(values.definition);
} catch (error) {
setDefinitionError("Invalid JSON in tool definition");
setDefinitionError("Invalid JSON in action definition");
return;
}
@@ -453,17 +446,17 @@ export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
}
if (response.error) {
setPopup({
message: "Failed to create tool - " + response.error,
message: "Failed to create action - " + response.error,
type: "error",
});
return;
}
router.push(`/admin/tools?u=${Date.now()}`);
router.push(`/admin/actions?u=${Date.now()}`);
}}
>
{({ isSubmitting, values, setFieldValue }) => {
return (
<ToolForm
<ActionForm
existingTool={tool}
values={values}
setFieldValue={setFieldValue}

View File

@@ -15,7 +15,7 @@ import { TrashIcon } from "@/components/icons/icons";
import { deleteCustomTool } from "@/lib/tools/edit";
import { TableHeader } from "@/components/ui/table";
export function ToolsTable({ tools }: { tools: ToolSnapshot[] }) {
export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
const router = useRouter();
const { popup, setPopup } = usePopup();

View File

@@ -2,7 +2,7 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
import CardSection from "@/components/admin/CardSection";
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
import { ActionEditor } from "@/app/admin/actions/ActionEditor";
import { fetchToolByIdSS } from "@/lib/tools/fetchTools";
import { DeleteToolButton } from "./DeleteToolButton";
import { AdminPageTitle } from "@/components/admin/Title";
@@ -31,7 +31,7 @@ export default async function Page(props: {
<div>
<div>
<CardSection>
<ToolEditor tool={tool} />
<ActionEditor tool={tool} />
</CardSection>
<Title className="mt-12">Delete Tool</Title>

View File

@@ -1,6 +1,6 @@
"use client";
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
import { ActionEditor } from "@/app/admin/actions/ActionEditor";
import { BackButton } from "@/components/BackButton";
import { AdminPageTitle } from "@/components/admin/Title";
import { ToolIcon } from "@/components/icons/icons";
@@ -12,12 +12,12 @@ export default function NewToolPage() {
<BackButton />
<AdminPageTitle
title="Create Tool"
title="Create Action"
icon={<ToolIcon size={32} className="my-auto" />}
/>
<CardSection>
<ToolEditor />
<ActionEditor />
</CardSection>
</div>
);

View File

@@ -1,4 +1,4 @@
import { ToolsTable } from "./ToolsTable";
import { ActionsTable } from "./ActionTable";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { FiPlusSquare } from "react-icons/fi";
import Link from "next/link";
@@ -29,23 +29,23 @@ export default async function Page() {
<div className="mx-auto container">
<AdminPageTitle
icon={<ToolIcon size={32} className="my-auto" />}
title="Tools"
title="Actions"
/>
<Text className="mb-2">
Tools allow assistants to retrieve information or take actions.
Actions allow assistants to retrieve information or take actions.
</Text>
<div>
<Separator />
<Title>Create a Tool</Title>
<CreateButton href="/admin/tools/new" text="New Tool" />
<Title>Create an Action</Title>
<CreateButton href="/admin/actions/new" text="New Action" />
<Separator />
<Title>Existing Tools</Title>
<ToolsTable tools={tools} />
<Title>Existing Actions</Title>
<ActionsTable tools={tools} />
</div>
</div>
);

View File

@@ -1095,8 +1095,7 @@ export function AssistantEditor({
{values.is_public ? (
<p className="text-sm text-text-dark">
Anyone from your organization can view and use this
assistant
Anyone from your team can view and use this assistant
</p>
) : (
<>

View File

@@ -177,6 +177,11 @@ export function PersonasTable() {
entityName={personaToToggleDefault.name}
onClose={closeDefaultModal}
onSubmit={handleToggleDefault}
actionText={
personaToToggleDefault.is_default_persona
? "remove the featured status of"
: "set as featured"
}
actionButtonText={
personaToToggleDefault.is_default_persona
? "Remove Featured"

View File

@@ -121,7 +121,7 @@ function Main() {
);
}
function Page() {
export default function Page() {
return (
<div className="mx-auto container">
<AdminPageTitle
@@ -132,5 +132,3 @@ function Page() {
</div>
);
}
export default Page;

View File

@@ -114,8 +114,8 @@ function Main() {
<ul className="list-disc mt-2 ml-4 mb-2">
<li>
<Text>
Set a global rate limit to control your organization&apos;s overall
token spend.
Set a global rate limit to control your team&apos;s overall token
spend.
</Text>
</li>
{isPaidEnterpriseFeaturesEnabled && (

View File

@@ -21,7 +21,8 @@ import { InvitedUserSnapshot } from "@/lib/types";
import { SearchBar } from "@/components/search/SearchBar";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import PendingUsersTable from "@/components/admin/users/PendingUsersTable";
import { useUser } from "@/components/user/UserProvider";
const UsersTables = ({
q,
setPopup,
@@ -44,6 +45,15 @@ const UsersTables = ({
errorHandlingFetcher
);
const {
data: pendingUsers,
error: pendingUsersError,
isLoading: pendingUsersLoading,
mutate: pendingUsersMutate,
} = useSWR<InvitedUserSnapshot[]>(
NEXT_PUBLIC_CLOUD_ENABLED ? "/api/tenants/users/pending" : null,
errorHandlingFetcher
);
// Show loading animation only during the initial data fetch
if (!validDomains) {
return <ThreeDotsLoader />;
@@ -63,6 +73,9 @@ const UsersTables = ({
<TabsList>
<TabsTrigger value="current">Current Users</TabsTrigger>
<TabsTrigger value="invited">Invited Users</TabsTrigger>
{NEXT_PUBLIC_CLOUD_ENABLED && (
<TabsTrigger value="pending">Pending Users</TabsTrigger>
)}
</TabsList>
<TabsContent value="current">
@@ -97,6 +110,25 @@ const UsersTables = ({
</CardContent>
</Card>
</TabsContent>
{NEXT_PUBLIC_CLOUD_ENABLED && (
<TabsContent value="pending">
<Card>
<CardHeader>
<CardTitle>Pending Users</CardTitle>
</CardHeader>
<CardContent>
<PendingUsersTable
users={pendingUsers || []}
setPopup={setPopup}
mutate={pendingUsersMutate}
error={pendingUsersError}
isLoading={pendingUsersLoading}
q={q}
/>
</CardContent>
</Card>
</TabsContent>
)}
</Tabs>
);
};
@@ -190,7 +222,7 @@ const AddUserButton = ({
entityName="your Access Logic"
onClose={() => setShowConfirmation(false)}
onSubmit={handleConfirmFirstInvite}
additionalDetails="After inviting the first user, only invited users will be able to join this platform. This is a security measure to control access to your instance."
additionalDetails="After inviting the first user, only invited users will be able to join this platform. This is a security measure to control access to your team."
actionButtonText="Continue"
variant="action"
/>

View File

@@ -18,8 +18,8 @@ const Page = () => {
need to either:
</p>
<ul className="list-disc text-left text-text-600 w-full pl-6 mx-auto">
<li>Be invited to an existing Onyx organization</li>
<li>Create a new Onyx organization</li>
<li>Be invited to an existing Onyx team</li>
<li>Create a new Onyx team</li>
</ul>
<div className="flex justify-center">
<Link

View File

@@ -0,0 +1,108 @@
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { User } from "@/lib/types";
import {
getCurrentUserSS,
getAuthTypeMetadataSS,
AuthTypeMetadata,
getAuthUrlSS,
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm";
import Text from "@/components/ui/text";
import Link from "next/link";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import AuthErrorDisplay from "@/components/auth/AuthErrorDisplay";
const Page = async (props: {
searchParams?: Promise<{ [key: string]: string | string[] | undefined }>;
}) => {
const searchParams = await props.searchParams;
const nextUrl = Array.isArray(searchParams?.next)
? searchParams?.next[0]
: searchParams?.next || null;
const defaultEmail = Array.isArray(searchParams?.email)
? searchParams?.email[0]
: searchParams?.email || null;
const teamName = Array.isArray(searchParams?.team)
? searchParams?.team[0]
: searchParams?.team || "your team";
// catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page
// will not render
let authTypeMetadata: AuthTypeMetadata | null = null;
let currentUser: User | null = null;
try {
[authTypeMetadata, currentUser] = await Promise.all([
getAuthTypeMetadataSS(),
getCurrentUserSS(),
]);
} catch (e) {
console.log(`Some fetch failed for the login page - ${e}`);
}
// simply take the user to the home page if Auth is disabled
if (authTypeMetadata?.authType === "disabled") {
return redirect("/chat");
}
// if user is already logged in, take them to the main app page
if (currentUser && currentUser.is_active && !currentUser.is_anonymous_user) {
if (!authTypeMetadata?.requiresVerification || currentUser.is_verified) {
return redirect("/chat");
}
return redirect("/auth/waiting-on-verification");
}
const cloud = authTypeMetadata?.authType === "cloud";
// only enable this page if basic login is enabled
if (authTypeMetadata?.authType !== "basic" && !cloud) {
return redirect("/chat");
}
let authUrl: string | null = null;
if (cloud && authTypeMetadata) {
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
}
const emailDomain = defaultEmail?.split("@")[1];
return (
<AuthFlowContainer authState="join">
<HealthCheckBanner />
<AuthErrorDisplay searchParams={searchParams} />
<>
<div className="absolute top-10x w-full"></div>
<div className="flex w-full flex-col justify-center">
<h2 className="text-center text-xl text-strong font-bold">
Re-authenticate to join team
</h2>
{cloud && authUrl && (
<div className="w-full justify-center">
<SignInButton authorizeUrl={authUrl} authType="cloud" />
<div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-background-300"></div>
<span className="px-4 text-text-500">or</span>
<div className="flex-grow border-t border-background-300"></div>
</div>
</div>
)}
<EmailPasswordForm
isSignup
isJoin
shouldVerify={authTypeMetadata?.requiresVerification}
nextUrl={nextUrl}
defaultEmail={defaultEmail}
/>
</div>
</>
</AuthFlowContainer>
);
};
export default Page;

View File

@@ -13,6 +13,7 @@ import { set } from "lodash";
import { NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED } from "@/lib/constants";
import Link from "next/link";
import { useUser } from "@/components/user/UserProvider";
import { useRouter } from "next/navigation";
export function EmailPasswordForm({
isSignup = false,
@@ -20,15 +21,18 @@ export function EmailPasswordForm({
referralSource,
nextUrl,
defaultEmail,
isJoin = false,
}: {
isSignup?: boolean;
shouldVerify?: boolean;
referralSource?: string;
nextUrl?: string | null;
defaultEmail?: string | null;
isJoin?: boolean;
}) {
const { user } = useUser();
const { popup, setPopup } = usePopup();
const router = useRouter();
const [isWorking, setIsWorking] = useState(false);
return (
<>
@@ -79,6 +83,11 @@ export function EmailPasswordForm({
});
setIsWorking(false);
return;
} else {
setPopup({
type: "success",
message: "Account created successfully. Please log in.",
});
}
}
@@ -92,7 +101,9 @@ export function EmailPasswordForm({
window.location.href = "/auth/waiting-on-verification";
} else {
// See above comment
window.location.href = nextUrl ? encodeURI(nextUrl) : "/";
window.location.href = nextUrl
? encodeURI(nextUrl)
: `/chat${isSignup && !isJoin ? "?new_team=true" : ""}`;
}
} else {
setIsWorking(false);
@@ -135,11 +146,12 @@ export function EmailPasswordForm({
/>
<Button
variant="agent"
type="submit"
disabled={isSubmitting}
className="mx-auto !py-4 w-full"
>
{isSignup ? "Sign Up" : "Log In"}
{isJoin ? "Join" : isSignup ? "Sign Up" : "Log In"}
</Button>
{user?.is_anonymous_user && (
<Link

View File

@@ -51,25 +51,16 @@ export default function LoginPage({
</div>
<EmailPasswordForm shouldVerify={true} nextUrl={nextUrl} />
<div className="flex mt-4 justify-between">
<Link
href={`/auth/signup${
searchParams?.next ? `?next=${searchParams.next}` : ""
}`}
className="text-link font-medium"
>
Create an account
</Link>
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
<div className="flex mt-4 justify-between">
<Link
href="/auth/forgot-password"
className="text-link font-medium"
>
Reset Password
</Link>
)}
</div>
</div>
)}
</div>
)}

View File

@@ -46,7 +46,7 @@ export function SignInButton({
return (
<a
className="mx-auto mb-4 mt-6 py-3 w-full text-neutral-100 bg-indigo-500 flex rounded cursor-pointer hover:bg-indigo-800"
className="mx-auto mb-4 mt-6 py-3 w-full dark:text-neutral-300 text-neutral-600 border border-neutral-300 flex rounded cursor-pointer hover:border-neutral-400 transition-colors"
href={finalAuthorizeUrl}
>
{button}

View File

@@ -215,11 +215,7 @@ export function ChatPage({
const isInitialLoad = useRef(true);
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
const {
assistants: availableAssistants,
finalAssistants,
pinnedAssistants,
} = useAssistants();
const { assistants: availableAssistants, pinnedAssistants } = useAssistants();
const [showApiKeyModal, setShowApiKeyModal] = useState(
!shouldShowWelcomeModal
@@ -229,7 +225,7 @@ export function ChatPage({
const slackChatId = searchParams.get("slackChatId");
const existingChatIdRaw = searchParams.get("chatId");
const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open
const [showHistorySidebar, setShowHistorySidebar] = useState(false);
const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
@@ -2451,7 +2447,7 @@ export function ChatPage({
h-full
${sidebarVisible ? "w-[200px]" : "w-[0px]"}
`}
></div>
/>
)}
</div>
)}

View File

@@ -117,9 +117,8 @@ export function ShareChatSessionModal({
{shareLink ? (
<div>
<Text>
This chat session is currently shared. Anyone in your
organization can view the message history using the following
link:
This chat session is currently shared. Anyone in your team can
view the message history using the following link:
</Text>
<div className="flex mt-2">
@@ -160,7 +159,7 @@ export function ShareChatSessionModal({
<div>
<Callout type="warning" title="Warning" className="mb-4">
Please make sure that all content in this chat is safe to
share with the whole organization.
share with the whole team.
</Callout>
<div className="flex w-full justify-between">
<Button

View File

@@ -12,7 +12,7 @@ function Main() {
<div className="mt-4">
<Callout type="danger" title="Custom Analytics is not enabled.">
To set up custom analytics scripts, please work with the team who
setup Onyx in your organization to set the{" "}
setup Onyx in your team to set the{" "}
<i>CUSTOM_ANALYTICS_SECRET_KEY</i> environment variable.
</Callout>
</div>

View File

@@ -140,7 +140,7 @@ export function WhitelabelingForm() {
<TextFormField
label="Application Name"
name="application_name"
subtext={`The custom name you are giving Onyx for your organization. This will replace 'Onyx' everywhere in the UI.`}
subtext={`The custom name you are giving Onyx for your team. This will replace 'Onyx' everywhere in the UI.`}
placeholder="Custom name which will replace 'Onyx'"
disabled={isSubmitting}
/>

View File

@@ -202,10 +202,10 @@ export function ClientLayout({
className="text-text-700"
size={18}
/>
<div className="ml-1">Tools</div>
<div className="ml-1">Actions</div>
</div>
),
link: "/admin/tools",
link: "/admin/actions",
},
]
: []),

View File

@@ -2,19 +2,8 @@
"use client";
import React, { useContext } from "react";
import Link from "next/link";
import { Logo } from "@/components/logo/Logo";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { WarningCircle, WarningDiamond } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { CgArrowsExpandUpLeft } from "react-icons/cg";
import LogoWithText from "@/components/header/LogoWithText";
import { LogoComponent } from "@/components/logo/FixedLogo";
interface Item {

View File

@@ -0,0 +1,154 @@
import { useState } from "react";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import {
Table,
TableHead,
TableRow,
TableBody,
TableCell,
} from "@/components/ui/table";
import CenteredPageSelector from "./CenteredPageSelector";
import { ThreeDotsLoader } from "@/components/Loading";
import { InvitedUserSnapshot } from "@/lib/types";
import { TableHeader } from "@/components/ui/table";
import { Button } from "@/components/ui/button";
import { ErrorCallout } from "@/components/ErrorCallout";
import { FetchError } from "@/lib/fetcher";
import { CheckIcon } from "lucide-react";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
const USERS_PER_PAGE = 10;
interface Props {
users: InvitedUserSnapshot[];
setPopup: (spec: PopupSpec) => void;
mutate: () => void;
error: FetchError | null;
isLoading: boolean;
q: string;
}
const PendingUsersTable = ({
users,
setPopup,
mutate,
error,
isLoading,
q,
}: Props) => {
const [currentPageNum, setCurrentPageNum] = useState<number>(1);
const [userToApprove, setUserToApprove] = useState<string | null>(null);
if (!users.length)
return <p>Users that have requested to join will show up here</p>;
const totalPages = Math.ceil(users.length / USERS_PER_PAGE);
// Filter users based on the search query
const filteredUsers = q
? users.filter((user) => user.email.includes(q))
: users;
// Get the current page of users
const currentPageOfUsers = filteredUsers.slice(
(currentPageNum - 1) * USERS_PER_PAGE,
currentPageNum * USERS_PER_PAGE
);
if (isLoading) {
return <ThreeDotsLoader />;
}
if (error) {
return (
<ErrorCallout
errorTitle="Error loading pending users"
errorMsg={error?.info?.detail}
/>
);
}
const handleAcceptRequest = async (email: string) => {
try {
await fetch("/api/tenants/users/invite/approve", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ email }),
});
mutate();
setUserToApprove(null);
} catch (error) {
setPopup({
type: "error",
message: "Failed to approve user request",
});
}
};
return (
<>
{userToApprove && (
<ConfirmEntityModal
entityType="Join Request"
entityName={userToApprove}
onClose={() => setUserToApprove(null)}
onSubmit={() => handleAcceptRequest(userToApprove)}
actionButtonText="Approve"
actionText="approve the join request of"
additionalDetails={`${userToApprove} has requested to join the team. Approving will add them as a user in this team.`}
variant="action"
accent
removeConfirmationText
/>
)}
<Table className="overflow-visible">
<TableHeader>
<TableRow>
<TableHead>Email</TableHead>
<TableHead>
<div className="flex justify-end">Actions</div>
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{currentPageOfUsers.length ? (
currentPageOfUsers.map((user) => (
<TableRow key={user.email}>
<TableCell>{user.email}</TableCell>
<TableCell>
<div className="flex justify-end">
<Button
variant="outline"
size="sm"
onClick={() => setUserToApprove(user.email)}
>
<CheckIcon className="h-4 w-4" />
Accept Join Request
</Button>
</div>
</TableCell>
</TableRow>
))
) : (
<TableRow>
<TableCell colSpan={2} className="h-24 text-center">
{`No pending users found matching "${q}"`}
</TableCell>
</TableRow>
)}
</TableBody>
</Table>
{totalPages > 1 ? (
<CenteredPageSelector
currentPage={currentPageNum}
totalPages={totalPages}
onPageChange={setCurrentPageNum}
/>
) : null}
</>
);
};
export default PendingUsersTable;

View File

@@ -22,19 +22,19 @@ export const LeaveOrganizationButton = ({
}) => {
const router = useRouter();
const { trigger, isMutating } = useSWRMutation(
"/api/tenants/leave-organization",
"/api/tenants/leave-team",
userMutationFetcher,
{
onSuccess: () => {
mutate();
setPopup({
message: "Successfully left the organization!",
message: "Successfully left the team!",
type: "success",
});
},
onError: (errorMsg) =>
setPopup({
message: `Unable to leave organization - ${errorMsg}`,
message: `Unable to leave team - ${errorMsg}`,
type: "error",
}),
}
@@ -53,11 +53,11 @@ export const LeaveOrganizationButton = ({
<ConfirmEntityModal
variant="action"
actionButtonText="Leave"
entityType="organization"
entityName="your organization"
entityType="team"
entityName="your team"
onClose={() => setShowLeaveModal(false)}
onSubmit={handleLeaveOrganization}
additionalDetails="You will lose access to all organization data and resources."
additionalDetails="You will lose access to all team data and resources."
/>
)}

View File

@@ -4,7 +4,7 @@ import { useEffect } from "react";
import { usePopup } from "../admin/connectors/Popup";
const ERROR_MESSAGES = {
Anonymous: "Your organization does not have anonymous access enabled.",
Anonymous: "Your team does not have anonymous access enabled.",
};
export default function AuthErrorDisplay({

View File

@@ -6,7 +6,7 @@ export default function AuthFlowContainer({
authState,
}: {
children: React.ReactNode;
authState?: "signup" | "login";
authState?: "signup" | "login" | "join";
}) {
return (
<div className="p-4 flex flex-col items-center justify-center min-h-screen bg-background">

View File

@@ -6,6 +6,8 @@ import { SettingsProvider } from "../settings/SettingsProvider";
import { AssistantsProvider } from "./AssistantsContext";
import { Persona } from "@/app/admin/assistants/interfaces";
import { User } from "@/lib/types";
import { ModalProvider } from "./ModalContext";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
interface AppProviderProps {
children: React.ReactNode;
@@ -16,6 +18,8 @@ interface AppProviderProps {
hasImageCompatibleModel: boolean;
}
//
export const AppProvider = ({
children,
user,
@@ -33,7 +37,7 @@ export const AppProvider = ({
hasAnyConnectors={hasAnyConnectors}
hasImageCompatibleModel={hasImageCompatibleModel}
>
{children}
<ModalProvider user={user}>{children}</ModalProvider>
</AssistantsProvider>
</ProviderContextProvider>
</UserProvider>

View File

@@ -0,0 +1,95 @@
"use client";
import React, { createContext, useContext, useState, useCallback } from "react";
import { NewTeamModal } from "../modals/NewTeamModal";
import NewTenantModal from "../modals/NewTenantModal";
import { User, NewTenantInfo } from "@/lib/types";
type ModalContextType = {
showNewTeamModal: boolean;
setShowNewTeamModal: (show: boolean) => void;
newTenantInfo: NewTenantInfo | null;
setNewTenantInfo: (info: NewTenantInfo | null) => void;
invitationInfo: NewTenantInfo | null;
setInvitationInfo: (info: NewTenantInfo | null) => void;
};
const ModalContext = createContext<ModalContextType | undefined>(undefined);
export const useModalContext = () => {
const context = useContext(ModalContext);
if (context === undefined) {
throw new Error("useModalContext must be used within a ModalProvider");
}
return context;
};
export const ModalProvider: React.FC<{
children: React.ReactNode;
user: User | null;
}> = ({ children, user }) => {
const [showNewTeamModal, setShowNewTeamModal] = useState(false);
const [newTenantInfo, setNewTenantInfo] = useState<NewTenantInfo | null>(
user?.tenant_info?.new_tenant || null
);
const [invitationInfo, setInvitationInfo] = useState<NewTenantInfo | null>(
user?.tenant_info?.invitation || null
);
// Initialize modal states based on user info
React.useEffect(() => {
if (user?.tenant_info?.new_tenant) {
setNewTenantInfo(user.tenant_info.new_tenant);
}
if (user?.tenant_info?.invitation) {
setInvitationInfo(user.tenant_info.invitation);
}
}, [user?.tenant_info]);
// Render all application-wide modals
const renderModals = () => {
if (!user) return null;
return (
<>
{/* Modal for users to request to join an existing team */}
<NewTeamModal />
{/* Modal for users who've been accepted to a new team */}
{newTenantInfo && (
<NewTenantModal
tenantInfo={newTenantInfo}
// Close function to clear the modal state
onClose={() => setNewTenantInfo(null)}
/>
)}
{/* Modal for users who've been invited to join a team */}
{invitationInfo && (
<NewTenantModal
isInvite={true}
tenantInfo={invitationInfo}
// Close function to clear the modal state
onClose={() => setInvitationInfo(null)}
/>
)}
</>
);
};
return (
<ModalContext.Provider
value={{
showNewTeamModal,
setShowNewTeamModal,
newTenantInfo,
setNewTenantInfo,
invitationInfo,
setInvitationInfo,
}}
>
{children}
{renderModals()}
</ModalContext.Provider>
);
};

View File

@@ -197,23 +197,21 @@ export default function ModifyCredential({
Are you sure you want to delete this credential? You cannot delete
credentials that are linked to live connectors.
</p>
<div className="mt-6 flex justify-between">
<button
className="rounded py-1.5 px-2 bg-background-800 text-text-200"
<div className="mt-6 flex gap-x-2 justify-end">
<Button
onClick={async () => {
await onDeleteCredential(confirmDeletionCredential);
setConfirmDeletionCredential(null);
}}
>
Yes
</button>
<button
Confirm
</Button>
<Button
variant="outline"
onClick={() => setConfirmDeletionCredential(null)}
className="rounded py-1.5 px-2 bg-background-150 text-text-800"
>
{" "}
No
</button>
Cancel
</Button>
</div>
</>
</Modal>

View File

@@ -8,8 +8,11 @@ export const ConfirmEntityModal = ({
entityName,
additionalDetails,
actionButtonText,
actionText,
includeCancelButton = true,
variant = "delete",
accent = false,
removeConfirmationText = false,
}: {
entityType: string;
entityName: string;
@@ -17,23 +20,21 @@ export const ConfirmEntityModal = ({
onSubmit: () => void;
additionalDetails?: string;
actionButtonText?: string;
actionText?: string;
includeCancelButton?: boolean;
variant?: "delete" | "action";
accent?: boolean;
removeConfirmationText?: boolean;
}) => {
const isDeleteVariant = variant === "delete";
const defaultButtonText = isDeleteVariant ? "Delete" : "Confirm";
const buttonText = actionButtonText || defaultButtonText;
const getActionText = () => {
if (isDeleteVariant) {
return "delete";
}
switch (entityType) {
case "Default Persona":
return "change the default status of";
default:
return "modify";
if (actionText) {
return actionText;
}
return isDeleteVariant ? "delete" : "modify";
};
return (
@@ -44,9 +45,11 @@ export const ConfirmEntityModal = ({
{buttonText} {entityType}
</h2>
</div>
<p className="mb-4">
Are you sure you want to {getActionText()} <b>{entityName}</b>?
</p>
{!removeConfirmationText && (
<p className="mb-4">
Are you sure you want to {getActionText()} <b>{entityName}</b>?
</p>
)}
{additionalDetails && <p className="mb-4">{additionalDetails}</p>}
<div className="flex justify-end gap-2">
{includeCancelButton && (
@@ -56,7 +59,9 @@ export const ConfirmEntityModal = ({
)}
<Button
onClick={onSubmit}
variant={isDeleteVariant ? "destructive" : "default"}
variant={
accent ? "agent" : isDeleteVariant ? "destructive" : "default"
}
>
{buttonText}
</Button>

View File

@@ -0,0 +1,226 @@
"use client";
import { useState, useEffect } from "react";
import { useRouter, useSearchParams } from "next/navigation";
import { Dialog } from "@headlessui/react";
import { Button } from "../ui/button";
import { usePopup } from "@/components/admin/connectors/Popup";
import { Building, ArrowRight, Send, CheckCircle } from "lucide-react";
import { useUser } from "../user/UserProvider";
import { useModalContext } from "../context/ModalContext";
interface TenantByDomainResponse {
tenant_id: string;
number_of_users: number;
creator_email: string;
}
export function NewTeamModal() {
const { showNewTeamModal, setShowNewTeamModal } = useModalContext();
const [existingTenant, setExistingTenant] =
useState<TenantByDomainResponse | null>(null);
const [isLoading, setIsLoading] = useState(true);
const [isSubmitting, setIsSubmitting] = useState(false);
const [hasRequestedInvite, setHasRequestedInvite] = useState(false);
const [error, setError] = useState<string | null>(null);
const { user } = useUser();
const appDomain = user?.email.split("@")[1];
const router = useRouter();
const searchParams = useSearchParams();
const { setPopup } = usePopup();
useEffect(() => {
const hasNewTeamParam = searchParams.has("new_team");
if (hasNewTeamParam) {
setShowNewTeamModal(true);
fetchTenantInfo();
// Remove the new_team parameter from the URL without page reload
const newParams = new URLSearchParams(searchParams.toString());
newParams.delete("new_team");
const newUrl =
window.location.pathname +
(newParams.toString() ? `?${newParams.toString()}` : "");
window.history.replaceState({}, "", newUrl);
}
}, [searchParams, setShowNewTeamModal]);
const fetchTenantInfo = async () => {
setIsLoading(true);
setError(null);
try {
const response = await fetch("/api/tenants/existing-team-by-domain");
if (!response.ok) {
throw new Error(`Failed to fetch team info: ${response.status}`);
}
const responseJson = await response.json();
if (!responseJson) {
setShowNewTeamModal(false);
setExistingTenant(null);
return;
}
const data = responseJson as TenantByDomainResponse;
setExistingTenant(data);
} catch (error) {
console.error("Failed to fetch tenant info:", error);
setError("Could not retrieve team information. Please try again later.");
} finally {
setIsLoading(false);
}
};
const handleRequestInvite = async () => {
if (!existingTenant) return;
setIsSubmitting(true);
setError(null);
try {
const response = await fetch("/api/tenants/users/invite/request", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ tenant_id: existingTenant.tenant_id }),
});
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
throw new Error(errorData.message || "Failed to request invite");
}
setHasRequestedInvite(true);
setPopup({
message: "Your invite request has been sent to the team admin.",
type: "success",
});
} catch (error) {
const message =
error instanceof Error ? error.message : "Failed to request an invite";
setError(message);
setPopup({
message,
type: "error",
});
} finally {
setIsSubmitting(false);
}
};
const handleContinueToNewOrg = () => {
const newUrl = window.location.pathname;
router.replace(newUrl);
setShowNewTeamModal(false);
};
// Update the close handler to use the context
const handleClose = () => {
setShowNewTeamModal(false);
};
// Only render if showNewTeamModal is true
if (!showNewTeamModal || isLoading) return null;
return (
<Dialog
open={showNewTeamModal}
onClose={handleClose}
className="relative z-[1000]"
>
{/* Modal backdrop */}
<div className="fixed inset-0 bg-[#000]/50" aria-hidden="true" />
<div className="fixed inset-0 flex items-center justify-center p-4">
<Dialog.Panel className="mx-auto w-full max-w-md rounded-lg bg-white dark:bg-neutral-800 p-6 shadow-xl border border-neutral-200 dark:border-neutral-700">
<Dialog.Title className="text-xl font-semibold mb-4 flex items-center">
{hasRequestedInvite ? (
<>
<CheckCircle className="mr-2 h-5 w-5 text-neutral-900 dark:text-[#fff]" />
Join Request Sent
</>
) : (
<>
<Building className="mr-2 h-5 w-5" />
We found an existing team for {appDomain}
</>
)}
</Dialog.Title>
{isLoading ? (
<div className="py-8 text-center">
<div className="animate-spin rounded-full h-8 w-8 border-b-2 border-neutral-900 dark:border-neutral-100 mx-auto mb-4"></div>
<p>Loading team information...</p>
</div>
) : error ? (
<div className="space-y-4">
<p className="text-red-500 dark:text-red-400">{error}</p>
<div className="flex w-full pt-2">
<Button
variant="agent"
onClick={handleContinueToNewOrg}
className="flex w-full text-center items-center justify-center"
>
Continue with new team
<ArrowRight className="ml-2 h-4 w-4" />
</Button>
</div>
</div>
) : hasRequestedInvite ? (
<div className="space-y-4">
<p className="text-neutral-700 dark:text-neutral-200">
Your join request has been sent. You can explore as your own
team while waiting for an admin of {appDomain} to approve your
request.
</p>
<div className="flex w-full pt-2">
<Button
variant="agent"
onClick={handleContinueToNewOrg}
className="flex w-full text-center items-center justify-center"
>
Try Onyx while waiting
<ArrowRight className="ml-2 h-4 w-4" />
</Button>
</div>
</div>
) : (
<div className="space-y-4">
<p className="text-neutral-500 dark:text-neutral-200 text-sm mb-2">
Your join request can be approved by any admin of {appDomain}.
</p>
<div className="mt-4">
<Button
onClick={handleRequestInvite}
variant="agent"
className="flex w-full items-center justify-center"
disabled={isSubmitting}
>
{isSubmitting ? (
<span className="flex items-center">
<span className="animate-spin mr-2"></span>
Sending request...
</span>
) : (
<>
<Send className="mr-2 h-4 w-4" />
Request to join your team
</>
)}
</Button>
</div>
<div
onClick={handleContinueToNewOrg}
className="flex hover:underline cursor-pointer text-link text-sm flex-col space-y-3 pt-0"
>
+ Continue with new team
</div>
</div>
)}
</Dialog.Panel>
</div>
</Dialog>
);
}

View File

@@ -0,0 +1,227 @@
"use client";
import { useState } from "react";
import { Dialog } from "@headlessui/react";
import { Button } from "../ui/button";
import { usePopup } from "@/components/admin/connectors/Popup";
import { ArrowRight, X } from "lucide-react";
import { logout } from "@/lib/user";
import { useUser } from "../user/UserProvider";
import { NewTenantInfo } from "@/lib/types";
import { useRouter } from "next/navigation";
// App domain should not be hardcoded
const APP_DOMAIN = process.env.NEXT_PUBLIC_APP_DOMAIN || "onyx.app";
interface NewTenantModalProps {
tenantInfo: NewTenantInfo;
isInvite?: boolean;
onClose?: () => void;
}
export default function NewTenantModal({
tenantInfo,
isInvite = false,
onClose,
}: NewTenantModalProps) {
const router = useRouter();
const { setPopup } = usePopup();
const { user } = useUser();
const [isOpen, setIsOpen] = useState(true);
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const handleClose = () => {
setIsOpen(false);
onClose?.();
};
const handleJoinTenant = async () => {
setIsLoading(true);
setError(null);
try {
if (isInvite) {
// Accept the invitation through the API
const response = await fetch("/api/tenants/users/invite/accept", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ tenant_id: tenantInfo.tenant_id }),
});
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
throw new Error(errorData.message || "Failed to accept invitation");
}
setPopup({
message: "You have accepted the invitation.",
type: "success",
});
} else {
// For non-invite flow, just show success message
setPopup({
message: "Processing your team join request...",
type: "success",
});
}
// Common logout and redirect for both flows
await logout();
router.push(`/auth/join?email=${encodeURIComponent(user?.email || "")}`);
handleClose();
} catch (error) {
const message =
error instanceof Error
? error.message
: "Failed to join the team. Please try again.";
setError(message);
setPopup({
message,
type: "error",
});
} finally {
setIsLoading(false);
}
};
const handleRejectInvite = async () => {
if (!isInvite) return;
setIsLoading(true);
setError(null);
try {
// Deny the invitation through the API
const response = await fetch("/api/tenants/users/invite/deny", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ tenant_id: tenantInfo.tenant_id }),
});
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
throw new Error(errorData.message || "Failed to decline invitation");
}
setPopup({
message: "You have declined the invitation.",
type: "info",
});
handleClose();
} catch (error) {
const message =
error instanceof Error
? error.message
: "Failed to decline the invitation. Please try again.";
setError(message);
setPopup({
message,
type: "error",
});
} finally {
setIsLoading(false);
}
};
if (!isOpen) return null;
return (
<Dialog open={isOpen} onClose={handleClose} className="relative z-[1000]">
{/* Modal backdrop */}
<div className="fixed inset-0 bg-[#000]/50" aria-hidden="true" />
<div className="fixed inset-0 flex items-center justify-center p-4">
<Dialog.Panel className="mx-auto w-full max-w-md rounded-lg bg-white dark:bg-neutral-800 p-6 shadow-xl border border-neutral-200 dark:border-neutral-700">
<Dialog.Title className="text-xl font-semibold mb-4 flex items-center">
{isInvite ? (
<>
You have been invited to join {tenantInfo.number_of_users}
other teammate{tenantInfo.number_of_users === 1
? ""
: "s"} of {APP_DOMAIN}.
</>
) : (
<>
Your request to join {tenantInfo.number_of_users} other users of{" "}
{APP_DOMAIN} has been approved.
</>
)}
</Dialog.Title>
<div className="space-y-4">
{error && (
<p className="text-red-500 dark:text-red-400 text-sm">{error}</p>
)}
<p className="text-sm text-neutral-600 dark:text-neutral-400">
{isInvite ? (
<>
By accepting this invitation, you will join the existing{" "}
{APP_DOMAIN} team and lose access to your current team.
<br />
Note: you will lose access to your current assistants,
prompts, chats, and connected sources.
</>
) : (
<>
To finish joining your team, please reauthenticate with{" "}
<em>{user?.email}</em>.
</>
)}
</p>
<div
className={`flex ${
isInvite ? "justify-between" : "justify-center"
} w-full pt-2 gap-4`}
>
{isInvite && (
<Button
onClick={handleRejectInvite}
variant="outline"
className="flex items-center flex-1"
disabled={isLoading}
>
{isLoading ? (
<span className="animate-spin mr-2"></span>
) : (
<X className="mr-2 h-4 w-4" />
)}
Decline
</Button>
)}
<Button
variant="agent"
onClick={handleJoinTenant}
className={`flex items-center justify-center ${
isInvite ? "flex-1" : "w-full"
}`}
disabled={isLoading}
>
{isLoading ? (
<span className="flex items-center">
<span className="animate-spin mr-2"></span>
{isInvite ? "Accepting..." : "Joining..."}
</span>
) : (
<>
{isInvite ? "Accept Invitation" : "Reauthenticate"}
<ArrowRight className="ml-2 h-4 w-4" />
</>
)}
</Button>
</div>
</div>
</Dialog.Panel>
</div>
</Dialog>
);
}

View File

@@ -76,8 +76,8 @@ export function UserProvider({
const identifyData: Record<string, any> = {
email: user.email,
};
if (user.organization_name) {
identifyData.organization_name = user.organization_name;
if (user.team_name) {
identifyData.team_name = user.team_name;
}
posthog.identify(user.id, identifyData);
} else {

View File

@@ -57,7 +57,7 @@ export interface User {
current_token_expiry_length?: number;
oidc_expiry?: Date;
is_cloud_superuser?: boolean;
organization_name: string | null;
team_name: string | null;
is_anonymous_user?: boolean;
// If user does not have a configured password
// (i.e.) they are using an oauth flow
@@ -65,6 +65,17 @@ export interface User {
// we don't want to show them things like the reset password
// functionality
password_configured?: boolean;
tenant_info?: TenantInfo | null;
}
export interface TenantInfo {
new_tenant?: NewTenantInfo | null;
invitation?: NewTenantInfo | null;
}
export interface NewTenantInfo {
tenant_id: string;
number_of_users: number;
}
export interface AllUsersResponse {

View File

@@ -1,6 +1,6 @@
import { cookies } from "next/headers";
import { User } from "./types";
import { buildUrl } from "./utilsSS";
import { buildUrl, UrlBuilder } from "./utilsSS";
import { ReadonlyRequestCookies } from "next/dist/server/web/spec-extension/adapters/request-cookies";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
@@ -55,13 +55,12 @@ export const getAuthDisabledSS = async (): Promise<boolean> => {
};
const getOIDCAuthUrlSS = async (nextUrl: string | null): Promise<string> => {
const res = await fetch(
buildUrl(
`/auth/oidc/authorize${
nextUrl ? `?next=${encodeURIComponent(nextUrl)}` : ""
}`
)
);
const url = UrlBuilder.fromInternalUrl("/auth/oidc/authorize");
if (nextUrl) {
url.addParam("next", nextUrl);
}
const res = await fetch(url.toString());
if (!res.ok) {
throw new Error("Failed to fetch data");
}
@@ -71,18 +70,16 @@ const getOIDCAuthUrlSS = async (nextUrl: string | null): Promise<string> => {
};
const getGoogleOAuthUrlSS = async (nextUrl: string | null): Promise<string> => {
const res = await fetch(
buildUrl(
`/auth/oauth/authorize${
nextUrl ? `?next=${encodeURIComponent(nextUrl)}` : ""
}`
),
{
headers: {
cookie: processCookies(await cookies()),
},
}
);
const url = UrlBuilder.fromInternalUrl("/auth/oauth/authorize");
if (nextUrl) {
url.addParam("next", nextUrl);
}
const res = await fetch(url.toString(), {
headers: {
cookie: processCookies(await cookies()),
},
});
if (!res.ok) {
throw new Error("Failed to fetch data");
}
@@ -92,13 +89,12 @@ const getGoogleOAuthUrlSS = async (nextUrl: string | null): Promise<string> => {
};
const getSAMLAuthUrlSS = async (nextUrl: string | null): Promise<string> => {
const res = await fetch(
buildUrl(
`/auth/saml/authorize${
nextUrl ? `?next=${encodeURIComponent(nextUrl)}` : ""
}`
)
);
const url = UrlBuilder.fromInternalUrl("/auth/saml/authorize");
if (nextUrl) {
url.addParam("next", nextUrl);
}
const res = await fetch(url.toString());
if (!res.ok) {
throw new Error("Failed to fetch data");
}
@@ -175,6 +171,7 @@ export const getCurrentUserSS = async (): Promise<User | null> => {
.join("; "),
},
});
if (!response.ok) {
return null;
}

View File

@@ -15,6 +15,47 @@ export function buildUrl(path: string) {
return `${INTERNAL_URL}/${path}`;
}
export class UrlBuilder {
private url: URL;
constructor(baseUrl: string) {
try {
this.url = new URL(baseUrl);
} catch (e) {
// Handle relative URLs by prepending a base
this.url = new URL(baseUrl, "http://placeholder.com");
}
}
addParam(key: string, value: string | number | boolean): UrlBuilder {
this.url.searchParams.set(key, String(value));
return this;
}
addParams(params: Record<string, string | number | boolean>): UrlBuilder {
Object.entries(params).forEach(([key, value]) => {
this.url.searchParams.set(key, String(value));
});
return this;
}
toString(): string {
// Extract just the path and query parts for relative URLs
if (this.url.origin === "http://placeholder.com") {
return `${this.url.pathname}${this.url.search}`;
}
return this.url.toString();
}
static fromInternalUrl(path: string): UrlBuilder {
return new UrlBuilder(buildUrl(path));
}
static fromClientUrl(path: string): UrlBuilder {
return new UrlBuilder(buildClientUrl(path));
}
}
export async function fetchSS(url: string, options?: RequestInit) {
const init = options || {
credentials: "include",

View File

@@ -2,7 +2,10 @@ import { test, expect } from "@chromatic-com/playwright";
import { loginAsRandomUser, loginAs } from "../utils/auth";
import { TEST_ADMIN2_CREDENTIALS, TEST_ADMIN_CREDENTIALS } from "../constants";
test("User changes password and logs in with new password", async ({
// test("User changes password and logs in with new password", async ({
// Skip this test for now
test.skip("User changes password and logs in with new password", async ({
page,
}) => {
// Clear browser context before starting the test
@@ -45,7 +48,8 @@ test("User changes password and logs in with new password", async ({
test.use({ storageState: "admin2_auth.json" });
test("Admin resets own password and logs in with new password", async ({
// Skip this test for now
test.skip("Admin resets own password and logs in with new password", async ({
page,
}) => {
const { email: adminEmail, password: adminPassword } =

View File

@@ -88,11 +88,11 @@
}
},
{
"name": "Custom Assistants - Tools",
"path": "tools",
"pageTitle": "Tools",
"name": "Custom Assistants - Actions",
"path": "actions",
"pageTitle": "Actions",
"options": {
"paragraphText": "Tools allow assistants to retrieve information or take actions."
"paragraphText": "Actions allow assistants to retrieve information or take actions."
}
},
{