Compare commits

..

3 Commits

Author SHA1 Message Date
Weves
9b169350a9 Switch to monotonic 2025-03-07 19:04:47 -08:00
Weves
c1dbb073d0 Small tweaks 2025-03-07 15:53:22 -08:00
Weves
39bfc6ae16 Add basic memory logging 2025-03-07 15:46:14 -08:00
119 changed files with 1710 additions and 4349 deletions

View File

@@ -48,8 +48,6 @@ env:
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
jobs:
connectors-check:

View File

@@ -114,4 +114,3 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -5,10 +5,7 @@ Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
import logging
from typing import cast
from alembic import op
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import text
@@ -18,45 +15,21 @@ down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
"""Conflicts on lowercasing will result in the uppercased email getting a
unique integer suffix when converted to lowercase."""
# Get database connection
connection = op.get_bind()
# Fetch all user emails that are not already lowercase
user_emails = connection.execute(
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
).fetchall()
for user_id, email in user_emails:
email = cast(str, email)
username, domain = email.rsplit("@", 1)
new_email = f"{username.lower()}@{domain.lower()}"
attempt = 1
while True:
try:
# Try updating the email
connection.execute(
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
{"new_email": new_email, "user_id": user_id},
)
break # Success, exit loop
except IntegrityError:
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
# Email conflict occurred, append `_1`, `_2`, etc., to the username
logger.warning(
f"Conflict while lowercasing email: "
f"old_email={email} "
f"conflicting_email={new_email} "
f"next_email={next_email}"
)
new_email = next_email
attempt += 1
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:

View File

@@ -1,33 +0,0 @@
"""add new available tenant table
Revision ID: 3b45e0018bf1
Revises: ac842f85f932
Create Date: 2025-03-06 09:55:18.229910
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "3b45e0018bf1"
down_revision = "ac842f85f932"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create new_available_tenant table
op.create_table(
"available_tenant",
sa.Column("tenant_id", sa.String(), nullable=False),
sa.Column("alembic_version", sa.String(), nullable=False),
sa.Column("date_created", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
)
def downgrade() -> None:
# Drop new_available_tenant table
op.drop_table("available_tenant")

View File

@@ -1,51 +0,0 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -1,14 +1,10 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@@ -18,19 +14,13 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -99,12 +89,6 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@@ -113,15 +97,6 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -138,104 +113,11 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@@ -1,5 +1,3 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -11,7 +9,6 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -91,64 +88,6 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -168,12 +107,6 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -1,45 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -1,98 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -1,24 +1,269 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,96 +0,0 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,30 +67,3 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -4,7 +4,6 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -15,7 +14,6 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -28,12 +26,11 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
@@ -63,72 +60,42 @@ async def get_or_provision_tenant(
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
# Early return for non-multi-tenant mode
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
# First, check if the user already has a tenant
tenant_id: str | None = None
try:
tenant_id = get_tenant_id_for_email(email)
return tenant_id
except exceptions.UserNotExists:
# User doesn't exist, so we need to create a new tenant or assign an existing one
pass
try:
# Try to get a pre-provisioned tenant
tenant_id = await get_available_tenant()
if tenant_id:
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
return tenant_id
else:
# If no pre-provisioned tenant is available, create a new one on-demand
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email, referral_source)
return tenant_id
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
except Exception as e:
# If we've encountered an error, log and raise an exception
error_msg = "Failed to provision tenant"
logger.error(error_msg, exc_info=e)
if not tenant_id:
raise HTTPException(
status_code=500,
detail="Failed to provision tenant. Please try again later.",
status_code=401, detail="User does not belong to an organization"
)
return tenant_id
async def create_tenant(email: str, referral_source: str | None = None) -> str:
"""
Create a new tenant on-demand when no pre-provisioned tenants are available.
This is the fallback method when we can't use a pre-provisioned tenant.
"""
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
logger.info(f"Creating new tenant {tenant_id} for user {email}")
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
try:
await rollback_tenant_provisioning(tenant_id)
except Exception:
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
@@ -142,25 +109,54 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
# Create the schema for the tenant
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
# Set up the tenant with all necessary configurations
await setup_tenant(tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Assign the tenant to the user
await assign_tenant_to_user(tenant_id, email)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(
@@ -191,74 +187,20 @@ async def notify_control_plane(
async def rollback_tenant_provisioning(tenant_id: str) -> None:
"""
Logic to rollback tenant provisioning on data plane.
Handles each step independently to ensure maximum cleanup even if some steps fail.
"""
# Logic to rollback tenant provisioning on data plane
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
# Track if any part of the rollback fails
rollback_errors = []
# 1. Try to drop the tenant's schema
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
except Exception as e:
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 2. Try to remove tenant mapping
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
logger.info(
f"Successfully removed user mappings for tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 3. If this tenant was in the available tenants table, remove it
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
available_tenant = (
db_session.query(AvailableTenant)
.filter(AvailableTenant.tenant_id == tenant_id)
.first()
)
if available_tenant:
db_session.delete(available_tenant)
db_session.commit()
logger.info(
f"Removed tenant {tenant_id} from available tenants table"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# Log summary of rollback operation
if rollback_errors:
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
else:
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
logger.error(f"Failed to rollback tenant provisioning: {e}")
def configure_default_api_keys(db_session: Session) -> None:
@@ -411,155 +353,3 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A dictionary containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None
async def get_available_tenant() -> str | None:
"""
Get an available pre-provisioned tenant from the NewAvailableTenant table.
Returns the tenant_id if one is available, None otherwise.
Uses row-level locking to prevent race conditions when multiple processes
try to get an available tenant simultaneously.
"""
if not MULTI_TENANT:
return None
with get_session_with_shared_schema() as db_session:
try:
db_session.begin()
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
available_tenant = (
db_session.query(AvailableTenant)
.order_by(AvailableTenant.date_created)
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
.first()
)
if available_tenant:
tenant_id = available_tenant.tenant_id
# Remove the tenant from the available tenants table
db_session.delete(available_tenant)
db_session.commit()
logger.info(f"Using pre-provisioned tenant {tenant_id}")
return tenant_id
else:
db_session.rollback()
return None
except Exception:
logger.exception("Error getting available tenant")
db_session.rollback()
return None
async def setup_tenant(tenant_id: str) -> None:
"""
Set up a tenant with all necessary configurations.
This is a centralized function that handles all tenant setup logic.
"""
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# Configure default API keys
configure_default_api_keys(db_session)
# Set up Onyx with appropriate settings
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
except Exception as e:
logger.exception(f"Failed to set up tenant {tenant_id}")
raise e
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def assign_tenant_to_user(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
"""
Assign a tenant to a user and perform necessary operations.
Uses transaction handling to ensure atomicity and includes retry logic
for control plane notifications.
"""
# First, add the user to the tenant in a transaction
try:
add_users_to_tenant([email], tenant_id)
# Create milestone record in the same transaction context as the tenant assignment
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")
# Notify control plane with retry logic
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)

View File

@@ -74,21 +74,3 @@ def drop_schema(tenant_id: str) -> None:
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)
def get_current_alembic_version(tenant_id: str) -> str:
"""Get the current Alembic version for a tenant."""
from alembic.runtime.migration import MigrationContext
from sqlalchemy import text
engine = get_sqlalchemy_engine()
# Set the search path to the tenant's schema
with engine.connect() as connection:
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Get the current version from the alembic_version table
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
return current_rev or "head"

View File

@@ -1,67 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,39 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -1,90 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,56 +1,27 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
logger = logging.getLogger(__name__)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
@@ -67,39 +38,13 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
# Start a transaction
db_session.begin()
for email in emails:
# Check if the user already has a mapping to this tenant
existing_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.first()
)
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Commit the transaction
db_session.commit()
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.rollback()
raise
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
@@ -131,187 +76,3 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None

View File

@@ -62,60 +62,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
@@ -146,17 +92,31 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
)
logger.error(error_string)
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -195,6 +155,7 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -278,51 +239,22 @@ class CloudEmbedding:
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@staticmethod
def create(
@@ -637,13 +569,6 @@ async def process_embed_request(
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,

View File

@@ -153,8 +153,7 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")

View File

@@ -1,6 +1,5 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -19,17 +18,3 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -100,7 +100,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
@@ -895,7 +894,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accessible_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1096,12 +1095,6 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1133,14 +1126,9 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
@@ -1152,7 +1140,6 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -112,6 +112,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -92,6 +92,5 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -5,53 +5,40 @@ from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE, maxBytes=MEMORY_LOG_MAX_BYTES, backupCount=MEMORY_LOG_BACKUP_COUNT
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()

View File

@@ -167,16 +167,6 @@ beat_cloud_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(minutes=10),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that only run self hosted

View File

@@ -1,199 +0,0 @@
"""
Periodic tasks for tenant pre-provisioning.
"""
import asyncio
import datetime
import uuid
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
@shared_task(
name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
queue=OnyxCeleryQueues.MONITORING,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_available_tenants(self: Task) -> None:
"""
Check if we have enough pre-provisioned tenants available.
If not, trigger the pre-provisioning of new tenants.
"""
task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
if not MULTI_TENANT:
task_logger.info(
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
)
return
r = get_redis_client()
lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# These tasks should never overlap
if not lock_check.acquire(blocking=False):
task_logger.info(
"Skipping check_available_tenants task because it is already running"
)
return
try:
# Get the current count of available tenants
with get_session_with_shared_schema() as db_session:
available_tenants_count = db_session.query(AvailableTenant).count()
# Get the target number of available tenants
target_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
)
# Calculate how many new tenants we need to provision
tenants_to_provision = max(
0, target_available_tenants - available_tenants_count
)
task_logger.info(
f"Available tenants: {available_tenants_count}, "
f"Target: {target_available_tenants}, "
f"To provision: {tenants_to_provision}"
)
# Trigger pre-provisioning tasks for each tenant needed
for _ in range(tenants_to_provision):
from celery import current_app
current_app.send_task(
OnyxCeleryTask.PRE_PROVISION_TENANT,
priority=OnyxCeleryPriority.LOW,
)
except Exception:
task_logger.exception("Error in check_available_tenants task")
finally:
lock_check.release()
@shared_task(
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
ignore_result=True,
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def pre_provision_tenant(self: Task) -> None:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.
This function fully sets up the tenant with all necessary configurations,
so it's ready to be assigned to a user immediately.
"""
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
# rather than inside this function
r = get_redis_client()
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
if not lock_provision.acquire(blocking=False):
task_logger.debug(
"Skipping pre_provision_tenant task because it is already running"
)
return
tenant_id: str | None = None
try:
# Generate a new tenant ID
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
# Create the schema for the new tenant
schema_created = create_schema_if_not_exists(tenant_id)
if schema_created:
task_logger.debug(f"Created schema for tenant: {tenant_id}")
else:
task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
# Set up the tenant with all necessary configurations
task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
asyncio.run(setup_tenant(tenant_id))
task_logger.debug(f"Tenant configuration completed: {tenant_id}")
# Get the current Alembic version
alembic_version = get_current_alembic_version(tenant_id)
task_logger.debug(
f"Tenant {tenant_id} using Alembic version: {alembic_version}"
)
# Store the pre-provisioned tenant in the database
task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
with get_session_with_shared_schema() as db_session:
# Use a transaction to ensure atomicity
db_session.begin()
try:
new_tenant = AvailableTenant(
tenant_id=tenant_id,
alembic_version=alembic_version,
date_created=datetime.datetime.now(),
)
db_session.add(new_tenant)
db_session.commit()
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
except Exception:
db_session.rollback()
task_logger.error(
f"Failed to store pre-provisioned tenant: {tenant_id}",
exc_info=True,
)
raise
except Exception:
task_logger.error("Error in pre_provision_tenant task", exc_info=True)
# If we have a tenant_id, attempt to rollback any partially completed provisioning
if tenant_id:
task_logger.info(
f"Rolling back failed tenant provisioning for: {tenant_id}"
)
try:
from ee.onyx.server.tenants.provisioning import (
rollback_tenant_provisioning,
)
asyncio.run(rollback_tenant_provisioning(tenant_id))
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
finally:
lock_provision.release()

View File

@@ -1,13 +1,10 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from typing import Union
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -47,44 +44,9 @@ class LlmDoc(BaseModel):
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
@@ -374,8 +336,6 @@ class AgentAnswerPiece(SubQuestionIdentifier):
class SubQuestionPiece(SubQuestionIdentifier):
"""Refined sub questions generated from the initial user question."""
sub_question: str
@@ -387,13 +347,13 @@ class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = Union[
AgentSearchPacket = (
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
]
)
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse

View File

@@ -643,6 +643,3 @@ MOCK_LLM_RESPONSE = (
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
# Number of pre-provisioned tenants to maintain
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))

View File

@@ -76,7 +76,6 @@ KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_PENDING_USERS_KEY = "PENDING_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"
@@ -322,8 +321,6 @@ class OnyxRedisLocks:
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
@@ -386,7 +383,6 @@ class OnyxCeleryTask:
CLOUD_MONITOR_CELERY_QUEUES = (
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
)
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -403,9 +399,6 @@ class OnyxCeleryTask:
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
# Tenant pre-provisioning
PRE_PROVISION_TENANT = "pre_provision_tenant"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"

View File

@@ -263,7 +263,6 @@ class ConfluenceConnector(
result = process_attachment(
self.confluence_client,
attachment,
page_id,
page_title,
self.image_analysis_llm,
)
@@ -367,7 +366,6 @@ class ConfluenceConnector(
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
page_context=confluence_xml,
llm=self.image_analysis_llm,
)

View File

@@ -1,3 +1,4 @@
import io
import json
import time
from collections.abc import Callable
@@ -18,11 +19,17 @@ from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
@@ -801,6 +808,65 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
if TYPE_CHECKING:
@@ -85,35 +84,25 @@ class AttachmentProcessingResult(BaseModel):
error: str | None = None
def _make_attachment_link(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
download_link = ""
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
def _download_attachment(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> bytes | None:
"""
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
"""
download_link = confluence_client.url + attachment["_links"]["download"]
resp = confluence_client._session.get(download_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with status code {resp.status_code}"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
return download_link
return None
return resp.content
def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
page_context: str,
llm: LLM | None,
) -> AttachmentProcessingResult:
@@ -133,52 +122,11 @@ def process_attachment(
error=f"Unsupported file type: {media_type}",
)
attachment_link = _make_attachment_link(
confluence_client, attachment, parent_content_id
)
if not attachment_link:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to make attachment link"
)
attachment_size = attachment["extensions"]["fileSize"]
if not media_type.startswith("image/") or not llm:
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {attachment_size} chars",
)
logger.info(
f"Downloading attachment: "
f"title={attachment['title']} "
f"length={attachment_size} "
f"link={attachment_link}"
)
# Download the attachment
resp: requests.Response = confluence_client._session.get(attachment_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {attachment_link} with status code {resp.status_code}"
)
raw_bytes = _download_attachment(confluence_client, attachment)
if raw_bytes is None:
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment download status code is {resp.status_code}",
)
raw_bytes = resp.content
if not raw_bytes:
return AttachmentProcessingResult(
text=None, file_name=None, error="attachment.content is None"
text=None, file_name=None, error="Failed to download attachment"
)
# Process image attachments with LLM if available
@@ -301,7 +249,6 @@ def _process_text_attachment(
def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
page_context: str,
llm: LLM | None,
) -> tuple[str | None, str | None] | None:
@@ -319,9 +266,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(
confluence_client, attachment, page_id, page_context, llm
)
result = process_attachment(confluence_client, attachment, page_context, llm)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"

View File

@@ -228,15 +228,10 @@ class GitbookConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("GitBook")
try:
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
content = self.client.get(f"/spaces/{self.space_id}/content")
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
logger.info(f"Found {len(pages)} root pages.")
logger.info(
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
)
while pages:
page = pages.pop(0)

View File

@@ -1,4 +1,3 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -205,15 +204,6 @@ class ConnectorCheckpoint(BaseModel):
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
class DocumentFailure(BaseModel):
document_id: str

View File

@@ -1,3 +1,4 @@
import time
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
@@ -31,7 +32,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
"""
filtered_pages: list[NotionPage] = []
for page in pages:
# Parse ISO 8601 timestamp and convert to UTC epoch time
timestamp = page[filter_field].replace(".000Z", "+00:00")
compare_time = datetime.fromisoformat(timestamp).timestamp()
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
}
while True:
db_res = self._search_notion(query_dict)
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
return
query_dict = {
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}

View File

@@ -48,12 +48,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
domain = "test" if credentials.get("is_sandbox") else None
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
domain=domain,
)
return None

View File

@@ -674,7 +674,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
3. Check that every channel specified in self.channels exists.
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
@@ -706,8 +706,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
f"Slack API returned a failure: {error_msg}"
)
# 3) If channels are specified and regex is not enabled, verify each is accessible
if self.channels and not self.channel_regex_enabled:
# 3) If channels are specified, verify each is accessible
if self.channels:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,

View File

@@ -2295,31 +2295,21 @@ class PublicBase(DeclarativeBase):
__abstract__ = True
# Strictly keeps track of the tenant that a given user will authenticate to.
class UserTenantMapping(Base):
__tablename__ = "user_tenant_mapping"
__table_args__ = ({"schema": "public"},)
__table_args__ = (
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
{"schema": "public"},
)
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
class AvailableTenant(Base):
__tablename__ = "available_tenant"
"""
These entries will only exist ephemerally and are meant to be picked up by new users on registration.
"""
tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False)
alembic_version: Mapped[str] = mapped_column(String, nullable=False)
date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
# This is a mapping from tenant IDs to anonymous user paths
class TenantAnonymousUserPath(Base):
__tablename__ = "tenant_anonymous_user_path"

View File

@@ -1,5 +1,6 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -148,10 +149,11 @@ def delete_document_tags_for_documents__no_commit(
stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
db_session.execute(stmt)
orphan_tags_query = select(Tag.id).where(
~db_session.query(Document__Tag.tag_id)
.filter(Document__Tag.tag_id == Tag.id)
.exists()
orphan_tags_query = (
select(Tag.id)
.outerjoin(Document__Tag, Tag.id == Document__Tag.tag_id)
.group_by(Tag.id)
.having(func.count(Document__Tag.document_id) == 0)
)
orphan_tags = db_session.execute(orphan_tags_query).scalars().all()

View File

@@ -234,8 +234,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
SqlEngine.reset_engine()
if AUTH_RATE_LIMITING_ENABLED:
await close_auth_limiter()

View File

@@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -112,7 +112,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accessible_user
or depends_fn == current_chat_accesssible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):

View File

@@ -17,7 +17,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
@@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel):
@router.get("/connector-status")
def get_basic_connector_indexing_status(
user: User = Depends(current_chat_accessible_user),
user: User = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
cc_pairs = get_connector_credential_pairs_for_user(

View File

@@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -390,7 +390,7 @@ def get_image_generation_tool(
@basic_router.get("")
def list_personas(
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),

View File

@@ -7,7 +7,7 @@ from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
@@ -191,7 +191,7 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [

View File

@@ -53,16 +53,6 @@ class UserPreferences(BaseModel):
temperature_override_enabled: bool | None = None
class TenantSnapshot(BaseModel):
tenant_id: str
number_of_users: int
class TenantInfo(BaseModel):
invitation: TenantSnapshot | None = None
new_tenant: TenantSnapshot | None = None
class UserInfo(BaseModel):
id: str
email: str
@@ -75,10 +65,9 @@ class UserInfo(BaseModel):
current_token_created_at: datetime | None = None
current_token_expiry_length: int | None = None
is_cloud_superuser: bool = False
team_name: str | None = None
organization_name: str | None = None
is_anonymous_user: bool | None = None
password_configured: bool | None = None
tenant_info: TenantInfo | None = None
@classmethod
def from_model(
@@ -87,9 +76,8 @@ class UserInfo(BaseModel):
current_token_created_at: datetime | None = None,
expiry_length: int | None = None,
is_cloud_superuser: bool = False,
team_name: str | None = None,
organization_name: str | None = None,
is_anonymous_user: bool | None = None,
tenant_info: TenantInfo | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
@@ -111,7 +99,7 @@ class UserInfo(BaseModel):
temperature_override_enabled=user.temperature_override_enabled,
)
),
team_name=team_name,
organization_name=organization_name,
# set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases
# where they previously had this set + used OIDC, and now they switched to
# basic auth are now constantly getting redirected back to the login page
@@ -121,7 +109,6 @@ class UserInfo(BaseModel):
current_token_expiry_length=expiry_length,
is_cloud_superuser=is_cloud_superuser,
is_anonymous_user=is_anonymous_user,
tenant_info=tenant_info,
)

View File

@@ -12,11 +12,13 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from psycopg2.errors import UniqueViolation
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import SUPER_USERS
@@ -53,8 +55,6 @@ from onyx.key_value_store.factory import get_kv_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import TenantInfo
from onyx.server.manage.models import TenantSnapshot
from onyx.server.manage.models import UserByEmail
from onyx.server.manage.models import UserInfo
from onyx.server.manage.models import UserPreferences
@@ -296,6 +296,13 @@ def bulk_invite_users(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
)(new_invited_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise HTTPException(
status_code=400,
detail="User has already been invited to a Onyx organization",
)
raise
except Exception as e:
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
@@ -418,10 +425,6 @@ async def delete_user(
db_session.expunge(user_to_delete)
try:
tenant_id = get_current_tenant_id()
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
delete_user_from_db(user_to_delete, db_session)
logger.info(f"Deleted user {user_to_delete.email}")
@@ -550,8 +553,8 @@ def verify_user_logged_in(
if anonymous_user_enabled(tenant_id=tenant_id):
store = get_kv_store()
return fetch_no_auth_user(store, anonymous_user_enabled=True)
raise BasicAuthenticationError(detail="User Not Authenticated")
raise BasicAuthenticationError(detail="User Not Authenticated")
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
@@ -560,35 +563,16 @@ def verify_user_logged_in(
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
)
team_name = fetch_ee_implementation_or_noop(
organization_name = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user.email)
new_tenant: TenantSnapshot | None = None
tenant_invitation: TenantSnapshot | None = None
if MULTI_TENANT:
if team_name != get_current_tenant_id():
user_count = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_count", None
)(team_name)
new_tenant = TenantSnapshot(tenant_id=team_name, number_of_users=user_count)
tenant_invitation = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_invitation", None
)(user.email)
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
is_cloud_superuser=user.email in SUPER_USERS,
team_name=team_name,
tenant_info=TenantInfo(
new_tenant=new_tenant,
invitation=tenant_invitation,
),
organization_name=organization_name,
)
return user_info

View File

@@ -49,9 +49,9 @@ class FullUserSnapshot(BaseModel):
)
class DisplayPriorityRequest(BaseModel):
display_priority_map: dict[int, int]
class InvitedUserSnapshot(BaseModel):
email: str
class DisplayPriorityRequest(BaseModel):
display_priority_map: dict[int, int]

View File

@@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import extract_headers
@@ -190,7 +190,7 @@ def update_chat_session_model(
def get_chat_session(
session_id: UUID,
is_shared: bool = False,
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> ChatSessionDetailResponse:
user_id = user.id if user is not None else None
@@ -246,7 +246,7 @@ def get_chat_session(
@router.post("/create-chat-session")
def create_new_chat_session(
chat_session_creation_request: ChatSessionCreationRequest,
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> CreateChatSessionID:
user_id = user.id if user is not None else None
@@ -381,7 +381,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
_rate_limit_check: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
) -> StreamingResponse:
@@ -473,7 +473,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None

View File

@@ -11,7 +11,7 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session_context_manager
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -29,7 +29,7 @@ TOKEN_BUDGET_UNIT = 1_000
def check_token_rate_limits(
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
) -> None:
# short circuit if no rate limits are set up
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time

View File

@@ -32,15 +32,15 @@ class InCodeToolInfo(TypedDict):
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
InCodeToolInfo(
cls=SearchTool,
description="The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
in_code_tool_id=SearchTool.__name__,
display_name=SearchTool._DISPLAY_NAME,
),
InCodeToolInfo(
cls=ImageGenerationTool,
description=(
"The Image Generation Action allows the assistant to use DALL-E 3 to generate images. "
"The action will be used when the user asks the assistant to generate an image."
"The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
"The tool will be used when the user asks the assistant to generate an image."
),
in_code_tool_id=ImageGenerationTool.__name__,
display_name=ImageGenerationTool._DISPLAY_NAME,
@@ -51,7 +51,7 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
InCodeToolInfo(
cls=InternetSearchTool,
description=(
"The Internet Search Action allows the assistant "
"The Internet Search Tool allows the assistant "
"to perform internet searches for up-to-date information."
),
in_code_tool_id=InternetSearchTool.__name__,
@@ -98,7 +98,7 @@ def load_builtin_tools(db_session: Session) -> None:
for tool_id, tool in list(in_code_tool_id_to_tool.items()):
if tool_id not in built_in_ids:
db_session.delete(tool)
logger.notice(f"Removed action no longer in built-in list: {tool.name}")
logger.notice(f"Removed tool no longer in built-in list: {tool.name}")
db_session.commit()
logger.notice("All built-in tools are loaded/verified.")

View File

@@ -1,43 +0,0 @@
from urllib.parse import parse_qs
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlunparse
def add_url_params(url: str, params: dict) -> str:
"""
Add parameters to a URL, handling existing parameters properly.
Args:
url: The original URL
params: Dictionary of parameters to add
Returns:
URL with added parameters
"""
# Parse the URL
parsed_url = urlparse(url)
# Get existing query parameters
query_params = parse_qs(parsed_url.query)
# Update with new parameters
for key, value in params.items():
query_params[key] = [value]
# Build the new query string
new_query = urlencode(query_params, doseq=True)
# Reconstruct the URL with the new query string
new_url = urlunparse(
(
parsed_url.scheme,
parsed_url.netloc,
parsed_url.path,
parsed_url.params,
new_query,
parsed_url.fragment,
)
)
return new_url

View File

@@ -1,4 +1,4 @@
black==23.7.0
black==23.3.0
boto3-stubs[s3]==1.34.133
celery-types==0.19.0
cohere==5.6.1

View File

@@ -54,7 +54,6 @@ class OnyxRedisCommand(Enum):
purge_vespa_syncing = "purge_vespa_syncing"
get_user_token = "get_user_token"
delete_user_token = "delete_user_token"
add_invited_user = "add_invited_user"
def __str__(self) -> str:
return self.value
@@ -164,21 +163,6 @@ def onyx_redis(
return 0
else:
return 2
elif command == OnyxRedisCommand.add_invited_user:
if not user_email:
logger.error("You must specify --user-email with add_invited_user")
return 1
current_invited_users = get_invited_users()
if user_email not in current_invited_users:
current_invited_users.append(user_email)
if dry_run:
logger.info(f"(DRY-RUN) Would add {user_email} to invited users")
else:
write_invited_users(current_invited_users)
logger.info(f"Added {user_email} to invited users")
else:
logger.info(f"{user_email} is already in the invited users list")
return 0
else:
pass
@@ -457,6 +441,23 @@ if __name__ == "__main__":
if args.tenant_id:
CURRENT_TENANT_ID_CONTEXTVAR.set(args.tenant_id)
if args.command == "add_invited_user":
if not args.user_email:
print("Error: --user-email is required for add_invited_user command")
sys.exit(1)
current_invited_users = get_invited_users()
if args.user_email not in current_invited_users:
current_invited_users.append(args.user_email)
if args.dry_run:
print(f"(DRY-RUN) Would add {args.user_email} to invited users")
else:
write_invited_users(current_invited_users)
print(f"Added {args.user_email} to invited users")
else:
print(f"{args.user_email} is already in the invited users list")
sys.exit(0)
exitcode = onyx_redis(
command=args.command,
batch=args.batch,

View File

@@ -36,7 +36,6 @@ def confluence_connector() -> ConfluenceConnector:
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_basic(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:

View File

@@ -28,7 +28,6 @@ def confluence_connector() -> ConfluenceConnector:
# This should never fail because even if the docs in the cloud change,
# the full doc ids retrieved should always be a subset of the slim doc ids
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_permissions(
confluence_connector: ConfluenceConnector,
) -> None:

View File

@@ -20,32 +20,29 @@ def gitbook_connector() -> GitbookConnector:
return connector
NUM_PAGES = 3
def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
doc_batch_generator = gitbook_connector.load_from_state()
# Get first batch of documents
doc_batch = next(doc_batch_generator)
assert len(doc_batch) == NUM_PAGES
assert len(doc_batch) > 0
# Verify first document structure
main_doc = doc_batch[0]
doc = doc_batch[0]
# Basic document properties
assert main_doc.id.startswith("gitbook-")
assert main_doc.semantic_identifier == "Acme Corp Internal Handbook"
assert main_doc.source == DocumentSource.GITBOOK
assert doc.id.startswith("gitbook-")
assert doc.semantic_identifier == "Acme Corp Internal Handbook"
assert doc.source == DocumentSource.GITBOOK
# Metadata checks
assert "path" in main_doc.metadata
assert "type" in main_doc.metadata
assert "kind" in main_doc.metadata
assert "path" in doc.metadata
assert "type" in doc.metadata
assert "kind" in doc.metadata
# Section checks
assert len(main_doc.sections) == 1
section = main_doc.sections[0]
assert len(doc.sections) == 1
section = doc.sections[0]
# Content specific checks
content = section.text
@@ -77,23 +74,8 @@ def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
assert section.link # Should have a URL
nested1 = doc_batch[1]
assert nested1.id.startswith("gitbook-")
assert nested1.semantic_identifier == "Nested1"
assert len(nested1.sections) == 1
# extra newlines at the end, remove them to make test easier
assert nested1.sections[0].text.strip() == "nested1"
assert nested1.source == DocumentSource.GITBOOK
nested2 = doc_batch[2]
assert nested2.id.startswith("gitbook-")
assert nested2.semantic_identifier == "Nested2"
assert len(nested2.sections) == 1
assert nested2.sections[0].text.strip() == "nested2"
assert nested2.source == DocumentSource.GITBOOK
# Time-based polling test
current_time = time.time()
poll_docs = gitbook_connector.poll_source(0, current_time)
poll_batch = next(poll_docs)
assert len(poll_batch) == NUM_PAGES
assert len(poll_batch) > 0

View File

@@ -1,128 +0,0 @@
import os
import time
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.notion.connector import NotionConnector
@pytest.fixture
def notion_connector() -> NotionConnector:
"""Create a NotionConnector with credentials from environment variables"""
connector = NotionConnector()
connector.load_credentials(
{
"notion_integration_token": os.environ["NOTION_INTEGRATION_TOKEN"],
}
)
return connector
def test_notion_connector_basic(notion_connector: NotionConnector) -> None:
"""Test the NotionConnector with a real Notion page.
Uses a Notion workspace under the onyx-test.com domain.
"""
doc_batch_generator = notion_connector.poll_source(0, time.time())
# Get first batch of documents
doc_batch = next(doc_batch_generator)
assert (
len(doc_batch) == 5
), "Expected exactly 5 documents (root, two children, table entry, and table entry child)"
# Find root and child documents by semantic identifier
root_doc = None
child1_doc = None
child2_doc = None
table_entry_doc = None
table_entry_child_doc = None
for doc in doc_batch:
if doc.semantic_identifier == "Root":
root_doc = doc
elif doc.semantic_identifier == "Child1":
child1_doc = doc
elif doc.semantic_identifier == "Child2":
child2_doc = doc
elif doc.semantic_identifier == "table-entry01":
table_entry_doc = doc
elif doc.semantic_identifier == "Child-table-entry01":
table_entry_child_doc = doc
assert root_doc is not None, "Root document not found"
assert child1_doc is not None, "Child1 document not found"
assert child2_doc is not None, "Child2 document not found"
assert table_entry_doc is not None, "Table entry document not found"
assert table_entry_child_doc is not None, "Table entry child document not found"
# Verify root document structure
assert root_doc.id is not None
assert root_doc.source == DocumentSource.NOTION
# Section checks for root
assert len(root_doc.sections) == 1
root_section = root_doc.sections[0]
# Content specific checks for root
assert root_section.text == "\nroot"
assert root_section.link is not None
assert root_section.link.startswith("https://www.notion.so/")
# Verify child1 document structure
assert child1_doc.id is not None
assert child1_doc.source == DocumentSource.NOTION
# Section checks for child1
assert len(child1_doc.sections) == 1
child1_section = child1_doc.sections[0]
# Content specific checks for child1
assert child1_section.text == "\nchild1"
assert child1_section.link is not None
assert child1_section.link.startswith("https://www.notion.so/")
# Verify child2 document structure (includes database)
assert child2_doc.id is not None
assert child2_doc.source == DocumentSource.NOTION
# Section checks for child2
assert len(child2_doc.sections) == 2 # One for content, one for database
child2_section = child2_doc.sections[0]
child2_db_section = child2_doc.sections[1]
# Content specific checks for child2
assert child2_section.text == "\nchild2"
assert child2_section.link is not None
assert child2_section.link.startswith("https://www.notion.so/")
# Database section checks for child2
assert child2_db_section.text.strip() != "" # Should contain some database content
assert child2_db_section.link is not None
assert child2_db_section.link.startswith("https://www.notion.so/")
# Verify table entry document structure
assert table_entry_doc.id is not None
assert table_entry_doc.source == DocumentSource.NOTION
# Section checks for table entry
assert len(table_entry_doc.sections) == 1
table_entry_section = table_entry_doc.sections[0]
# Content specific checks for table entry
assert table_entry_section.text == "\ntable-entry01"
assert table_entry_section.link is not None
assert table_entry_section.link.startswith("https://www.notion.so/")
# Verify table entry child document structure
assert table_entry_child_doc.id is not None
assert table_entry_child_doc.source == DocumentSource.NOTION
# Section checks for table entry child
assert len(table_entry_child_doc.sections) == 1
table_entry_child_section = table_entry_child_doc.sections[0]
# Content specific checks for table entry child
assert table_entry_child_section.text == "\nchild-table-entry01"
assert table_entry_child_section.link is not None
assert table_entry_child_section.link.startswith("https://www.notion.so/")

View File

@@ -2,7 +2,7 @@ FROM python:3.11.7-slim-bookworm
WORKDIR /app
RUN pip install "pydantic-core>=2.28.0" fastapi uvicorn
RUN pip install fastapi uvicorn
COPY ./main.py /app/main.py

View File

@@ -1,8 +1,3 @@
from typing import Any
import pytest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
@@ -22,58 +17,3 @@ def test_send_message_simple_with_history(reset: None) -> None:
)
assert len(response.full_message) > 0
@pytest.mark.skip(
reason="enable for autorun when we have a testing environment with semantically useful data"
)
def test_send_message_simple_with_history_buffered() -> None:
import requests
API_KEY = "" # fill in for this to work
headers = {}
headers["Authorization"] = f"Bearer {API_KEY}"
req: dict[str, Any] = {}
req["persona_id"] = 0
req["description"] = "test_send_message_simple_with_history_buffered"
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session", headers=headers, json=req
)
chat_session_id = response.json()["chat_session_id"]
req = {}
req["chat_session_id"] = chat_session_id
req["message"] = "What does onyx do?"
req["use_agentic_search"] = True
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-api", headers=headers, json=req
)
r_json = response.json()
# all of these should exist and be greater than length 1
assert len(r_json.get("answer", "")) > 0
assert len(r_json.get("agent_sub_questions", "")) > 0
assert len(r_json.get("agent_answers")) > 0
assert len(r_json.get("agent_sub_queries")) > 0
assert "agent_refined_answer_improvement" in r_json
# top level answer should match the one we select out of agent_answers
answer_level = 0
agent_level_answer = ""
agent_refined_answer_improvement = r_json.get("agent_refined_answer_improvement")
if agent_refined_answer_improvement:
answer_level = len(r_json["agent_answers"]) - 1
answers = r_json["agent_answers"][str(answer_level)]
for answer in answers:
if answer["answer_type"] == "agent_level_answer":
agent_level_answer = answer["answer"]
break
assert r_json["answer"] == agent_level_answer
assert response.status_code == 200

View File

@@ -1095,7 +1095,8 @@ export function AssistantEditor({
{values.is_public ? (
<p className="text-sm text-text-dark">
Anyone from your team can view and use this assistant
Anyone from your organization can view and use this
assistant
</p>
) : (
<>

View File

@@ -177,11 +177,6 @@ export function PersonasTable() {
entityName={personaToToggleDefault.name}
onClose={closeDefaultModal}
onSubmit={handleToggleDefault}
actionText={
personaToToggleDefault.is_default_persona
? "remove the featured status of"
: "set as featured"
}
actionButtonText={
personaToToggleDefault.is_default_persona
? "Remove Featured"

View File

@@ -121,7 +121,7 @@ function Main() {
);
}
export default function Page() {
function Page() {
return (
<div className="mx-auto container">
<AdminPageTitle
@@ -132,3 +132,5 @@ export default function Page() {
</div>
);
}
export default Page;

View File

@@ -1,3 +1,4 @@
import { Button } from "@/components/Button";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import React, { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
@@ -7,14 +8,10 @@ import { adminDeleteCredential } from "@/lib/credential";
import { setupGoogleDriveOAuth } from "@/lib/googleDrive";
import { GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
import Cookies from "js-cookie";
import {
TextFormField,
SectionHeader,
SubLabel,
} from "@/components/admin/connectors/Field";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Form, Formik } from "formik";
import { User } from "@/lib/types";
import { Button } from "@/components/ui/button";
import { Button as TremorButton } from "@/components/ui/button";
import {
Credential,
GoogleDriveCredentialJson,
@@ -23,15 +20,6 @@ import {
import { refreshAllGoogleData } from "@/lib/googleConnector";
import { ValidSources } from "@/lib/types";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
import {
FiFile,
FiUpload,
FiTrash2,
FiCheck,
FiLink,
FiAlertTriangle,
} from "react-icons/fi";
import { cn, truncateString } from "@/lib/utils";
type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account";
@@ -43,202 +31,126 @@ export const DriveJsonUpload = ({
onSuccess?: () => void;
}) => {
const { mutate } = useSWRConfig();
const [isUploading, setIsUploading] = useState(false);
const [fileName, setFileName] = useState<string | undefined>();
const [isDragging, setIsDragging] = useState(false);
const handleFileUpload = async (file: File) => {
setIsUploading(true);
setFileName(file.name);
const reader = new FileReader();
reader.onload = async (loadEvent) => {
if (!loadEvent?.target?.result) {
setIsUploading(false);
return;
}
const credentialJsonStr = loadEvent.target.result as string;
// Check credential type
let credentialFileType: GoogleDriveCredentialJsonTypes;
try {
const appCredentialJson = JSON.parse(credentialJsonStr);
if (appCredentialJson.web) {
credentialFileType = "authorized_user";
} else if (appCredentialJson.type === "service_account") {
credentialFileType = "service_account";
} else {
throw new Error(
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
);
}
} catch (e) {
setPopup({
message: `Invalid file provided - ${e}`,
type: "error",
});
setIsUploading(false);
return;
}
if (credentialFileType === "authorized_user") {
const response = await fetch(
"/api/manage/admin/connector/google-drive/app-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded app credentials",
type: "success",
});
mutate("/api/manage/admin/connector/google-drive/app-credential");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload app credentials - ${errorMsg}`,
type: "error",
});
}
}
if (credentialFileType === "service_account") {
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-key",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded service account key",
type: "success",
});
mutate(
"/api/manage/admin/connector/google-drive/service-account-key"
);
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload service account key - ${errorMsg}`,
type: "error",
});
}
}
setIsUploading(false);
};
reader.readAsText(file);
};
const handleDragEnter = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
if (!isUploading) {
setIsDragging(true);
}
};
const handleDragLeave = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
setIsDragging(false);
};
const handleDragOver = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
};
const handleDrop = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
setIsDragging(false);
if (isUploading) return;
const files = e.dataTransfer.files;
if (files.length > 0) {
const file = files[0];
if (file.type === "application/json" || file.name.endsWith(".json")) {
handleFileUpload(file);
} else {
setPopup({
message: "Please upload a JSON file",
type: "error",
});
}
}
};
const [credentialJsonStr, setCredentialJsonStr] = useState<
string | undefined
>();
return (
<div className="flex flex-col mt-4">
<div className="flex items-center">
<div className="relative flex flex-1 items-center">
<label
className={cn(
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
isUploading
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
: isDragging
? "bg-background-50/50 border-primary dark:border-primary"
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
)}
onDragEnter={handleDragEnter}
onDragLeave={handleDragLeave}
onDragOver={handleDragOver}
onDrop={handleDrop}
>
<div className="flex items-center space-x-2">
{isUploading ? (
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
) : (
<FiFile className="h-4 w-4 text-text-500" />
)}
<span className="text-sm text-text-500">
{isUploading
? `Uploading ${truncateString(fileName || "file", 50)}...`
: isDragging
? "Drop JSON file here"
: truncateString(
fileName || "Select or drag JSON credentials file...",
50
)}
</span>
</div>
<input
className="sr-only"
type="file"
accept=".json"
disabled={isUploading}
onChange={(event) => {
if (!event.target.files?.length) {
return;
}
const file = event.target.files[0];
handleFileUpload(file);
}}
/>
</label>
</div>
</div>
</div>
<>
<input
className={
"mr-3 text-sm text-text-900 border border-background-300 " +
"cursor-pointer bg-backgrournd dark:text-text-400 focus:outline-none " +
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
}
type="file"
accept=".json"
onChange={(event) => {
if (!event.target.files) {
return;
}
const file = event.target.files[0];
const reader = new FileReader();
reader.onload = function (loadEvent) {
if (!loadEvent?.target?.result) {
return;
}
const fileContents = loadEvent.target.result;
setCredentialJsonStr(fileContents as string);
};
reader.readAsText(file);
}}
/>
<Button
disabled={!credentialJsonStr}
onClick={async () => {
let credentialFileType: GoogleDriveCredentialJsonTypes;
try {
const appCredentialJson = JSON.parse(credentialJsonStr!);
if (appCredentialJson.web) {
credentialFileType = "authorized_user";
} else if (appCredentialJson.type === "service_account") {
credentialFileType = "service_account";
} else {
throw new Error(
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
);
}
} catch (e) {
setPopup({
message: `Invalid file provided - ${e}`,
type: "error",
});
return;
}
if (credentialFileType === "authorized_user") {
const response = await fetch(
"/api/manage/admin/connector/google-drive/app-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded app credentials",
type: "success",
});
mutate("/api/manage/admin/connector/google-drive/app-credential");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload app credentials - ${errorMsg}`,
type: "error",
});
}
}
if (credentialFileType === "service_account") {
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-key",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded service account key",
type: "success",
});
mutate(
"/api/manage/admin/connector/google-drive/service-account-key"
);
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload service account key - ${errorMsg}`,
type: "error",
});
}
}
}}
>
Upload
</Button>
</>
);
};
@@ -248,7 +160,6 @@ interface DriveJsonUploadSectionProps {
serviceAccountCredentialData?: { service_account_email: string };
isAdmin: boolean;
onSuccess?: () => void;
existingAuthCredential?: boolean;
}
export const DriveJsonUploadSection = ({
@@ -257,7 +168,6 @@ export const DriveJsonUploadSection = ({
serviceAccountCredentialData,
isAdmin,
onSuccess,
existingAuthCredential,
}: DriveJsonUploadSectionProps) => {
const { mutate } = useSWRConfig();
const router = useRouter();
@@ -267,7 +177,6 @@ export const DriveJsonUploadSection = ({
const [localAppCredentialData, setLocalAppCredentialData] =
useState(appCredentialData);
// Update local state when props change
useEffect(() => {
setLocalServiceAccountData(serviceAccountCredentialData);
setLocalAppCredentialData(appCredentialData);
@@ -281,135 +190,153 @@ export const DriveJsonUploadSection = ({
}
};
if (!isAdmin) {
if (localServiceAccountData?.service_account_email) {
return (
<div>
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<p className="text-sm">
Curators are unable to set up the Google Drive credentials. To add a
Google Drive connector, please contact an administrator.
<div className="mt-2 text-sm">
<div>
Found existing service account key with the following <b>Email:</b>
<p className="italic mt-1">
{localServiceAccountData.service_account_email}
</p>
</div>
{isAdmin ? (
<>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-key",
{
method: "DELETE",
}
);
if (response.ok) {
mutate(
"/api/manage/admin/connector/google-drive/service-account-key"
);
mutate(
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
);
setPopup({
message: "Successfully deleted service account key",
type: "success",
});
setLocalServiceAccountData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete service account key - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
</>
) : (
<>
<div className="mt-4 mb-1">
To change these credentials, please contact an administrator.
</div>
</>
)}
</div>
);
}
if (localAppCredentialData?.client_id) {
return (
<div className="mt-2 text-sm">
<div>
Found existing app credentials with the following <b>Client ID:</b>
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
</div>
{isAdmin ? (
<>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/google-drive/app-credential",
{
method: "DELETE",
}
);
if (response.ok) {
mutate(
"/api/manage/admin/connector/google-drive/app-credential"
);
mutate(
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
);
setPopup({
message: "Successfully deleted app credentials",
type: "success",
});
setLocalAppCredentialData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete app credential - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
</>
) : (
<div className="mt-4 mb-1">
To change these credentials, please contact an administrator.
</div>
)}
</div>
);
}
if (!isAdmin) {
return (
<div className="mt-2">
<p className="text-sm mb-2">
Curators are unable to set up the google drive credentials. To add a
Google Drive connector, please contact an administrator.
</p>
</div>
);
}
return (
<div>
<p className="text-sm mb-3">
To connect your Google Drive, create credentials (either OAuth App or
Service Account), download the JSON file, and upload it below.
</p>
<div className="mb-4">
<div className="mt-2">
<p className="text-sm mb-2">
Follow the guide{" "}
<a
className="text-primary hover:text-primary/80 flex items-center gap-1 text-sm"
className="text-link"
target="_blank"
href="https://docs.onyx.app/connectors/google_drive#authorization"
rel="noreferrer"
>
<FiLink className="h-3 w-3" />
View detailed setup instructions
</a>
</div>
{(localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id) && (
<div className="mb-4">
<div className="relative flex flex-1 items-center">
<label
className={cn(
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
false
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
)}
>
<div className="flex items-center space-x-2">
{false ? (
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
) : (
<FiFile className="h-4 w-4 text-text-500" />
)}
<span className="text-sm text-text-500">
{truncateString(
localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id ||
"",
50
)}
</span>
</div>
</label>
</div>
{isAdmin && !existingAuthCredential && (
<div className="mt-2">
<Button
variant="destructive"
type="button"
onClick={async () => {
const endpoint =
localServiceAccountData?.service_account_email
? "/api/manage/admin/connector/google-drive/service-account-key"
: "/api/manage/admin/connector/google-drive/app-credential";
const response = await fetch(endpoint, {
method: "DELETE",
});
if (response.ok) {
mutate(endpoint);
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
);
// Add additional mutations to refresh all credential-related endpoints
mutate(
"/api/manage/admin/connector/google-drive/credentials"
);
mutate(
"/api/manage/admin/connector/google-drive/public-credential"
);
mutate(
"/api/manage/admin/connector/google-drive/service-account-credential"
);
setPopup({
message: `Successfully deleted ${
localServiceAccountData
? "service account key"
: "app credentials"
}`,
type: "success",
});
// Immediately update local state
if (localServiceAccountData) {
setLocalServiceAccountData(undefined);
} else {
setLocalAppCredentialData(undefined);
}
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete credentials - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete Credentials
</Button>
</div>
)}
</div>
)}
{!(
localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id
) && <DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />}
here
</a>{" "}
to either (1) setup a google OAuth App in your company workspace or (2)
create a Service Account.
<br />
<br />
Download the credentials JSON if choosing option (1) or the Service
Account key JSON if chooosing option (2), and upload it here.
</p>
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
</div>
);
};
@@ -464,7 +391,6 @@ export const DriveAuthSection = ({
user,
}: DriveCredentialSectionProps) => {
const router = useRouter();
const [isAuthenticating, setIsAuthenticating] = useState(false);
const [localServiceAccountData, setLocalServiceAccountData] = useState(
serviceAccountKeyData
);
@@ -479,7 +405,6 @@ export const DriveAuthSection = ({
setLocalGoogleDriveServiceAccountCredential,
] = useState(googleDriveServiceAccountCredential);
// Update local state when props change
useEffect(() => {
setLocalServiceAccountData(serviceAccountKeyData);
setLocalAppCredentialData(appCredentialData);
@@ -499,181 +424,126 @@ export const DriveAuthSection = ({
localGoogleDriveServiceAccountCredential;
if (existingCredential) {
return (
<div>
<div className="mt-4">
<div className="py-3 px-4 bg-blue-50/30 dark:bg-blue-900/5 rounded mb-4 flex items-start">
<FiCheck className="text-blue-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<div className="flex-1">
<span className="font-medium block">Authentication Complete</span>
<p className="text-sm mt-1 text-text-500 dark:text-text-400 break-words">
Your Google Drive credentials have been successfully uploaded
and authenticated.
</p>
</div>
</div>
<Button
variant="destructive"
type="button"
onClick={async () => {
handleRevokeAccess(
connectorAssociated,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
</Button>
</div>
</div>
);
}
// If no credentials are uploaded, show message to complete step 1 first
if (
!localServiceAccountData?.service_account_email &&
!localAppCredentialData?.client_id
) {
return (
<div>
<SectionHeader>Google Drive Authentication</SectionHeader>
<div className="mt-4">
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<p className="text-sm">
Please complete Step 1 by uploading either OAuth credentials or a
Service Account key before proceeding with authentication.
</p>
</div>
</div>
</div>
<>
<p className="mb-2 text-sm">
<i>Uploaded and authenticated credential already exists!</i>
</p>
<Button
onClick={async () => {
handleRevokeAccess(
connectorAssociated,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
</Button>
</>
);
}
if (localServiceAccountData?.service_account_email) {
return (
<div>
<div className="mt-4">
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string()
.email("Must be a valid email")
.required("Required"),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
);
if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
refreshCredentials();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
}
} catch (error) {
setPopup({
message: `Failed to create service account credential - ${error}`,
type: "error",
});
} finally {
formikHelpers.setSubmitting(false);
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string().required(
"User email is required"
),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
const response = await fetch(
"/api/manage/admin/connector/google-drive/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
/>
<div className="flex">
<Button type="submit" disabled={isSubmitting}>
{isSubmitting ? "Creating..." : "Create Credential"}
</Button>
</div>
</Form>
)}
</Formik>
</div>
);
if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
}
refreshCredentials();
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
/>
<div className="flex">
<TremorButton type="submit" disabled={isSubmitting}>
Create Credential
</TremorButton>
</div>
</Form>
)}
</Formik>
</div>
);
}
if (localAppCredentialData?.client_id) {
return (
<div>
<div className="bg-background-50/30 dark:bg-background-900/20 rounded mb-4">
<p className="text-sm">
Next, you need to authenticate with Google Drive via OAuth. This
gives us read access to the documents you have access to in your
Google Drive account.
</p>
</div>
<div className="text-sm mb-4">
<p className="mb-2">
Next, you must provide credentials via OAuth. This gives us read
access to the docs you have access to in your google drive account.
</p>
<Button
disabled={isAuthenticating}
onClick={async () => {
setIsAuthenticating(true);
try {
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
isAdmin: true,
name: "OAuth (uploaded)",
});
if (authUrl) {
// cookie used by callback to determine where to finally redirect to
Cookies.set(GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
path: "/",
});
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
isAdmin: true,
name: "OAuth (uploaded)",
});
if (authUrl) {
router.push(authUrl);
} else {
setPopup({
message: errorMsg,
type: "error",
});
setIsAuthenticating(false);
}
} catch (error) {
setPopup({
message: `Failed to authenticate with Google Drive - ${error}`,
type: "error",
});
setIsAuthenticating(false);
router.push(authUrl);
return;
}
setPopup({
message: errorMsg,
type: "error",
});
}}
>
{isAuthenticating
? "Authenticating..."
: "Authenticate with Google Drive"}
Authenticate with Google Drive
</Button>
</div>
);
}
// This code path should not be reached with the new conditions above
return null;
// case where no keys have been uploaded in step 1
return (
<p className="text-sm">
Please upload either a OAuth Client Credential JSON or a Google Drive
Service Account Key JSON in Step 1 before moving onto Step 2.
</p>
);
};

View File

@@ -165,10 +165,6 @@ const GDriveMain = ({
serviceAccountCredentialData={serviceAccountKeyData}
isAdmin={isAdmin}
onSuccess={handleRefresh}
existingAuthCredential={Boolean(
googleDrivePublicUploadedCredential ||
googleDriveServiceAccountCredential
)}
/>
{isAdmin &&

View File

@@ -1,4 +1,4 @@
import { Button } from "@/components/ui/button";
import { Button } from "@/components/Button";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import React, { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
@@ -8,11 +8,7 @@ import { adminDeleteCredential } from "@/lib/credential";
import { setupGmailOAuth } from "@/lib/gmail";
import { GMAIL_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
import Cookies from "js-cookie";
import {
TextFormField,
SectionHeader,
SubLabel,
} from "@/components/admin/connectors/Field";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Form, Formik } from "formik";
import { User } from "@/lib/types";
import CardSection from "@/components/admin/CardSection";
@@ -24,19 +20,10 @@ import {
import { refreshAllGoogleData } from "@/lib/googleConnector";
import { ValidSources } from "@/lib/types";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
import {
FiFile,
FiUpload,
FiTrash2,
FiCheck,
FiLink,
FiAlertTriangle,
} from "react-icons/fi";
import { cn, truncateString } from "@/lib/utils";
type GmailCredentialJsonTypes = "authorized_user" | "service_account";
const GmailCredentialUpload = ({
const DriveJsonUpload = ({
setPopup,
onSuccess,
}: {
@@ -44,210 +31,134 @@ const GmailCredentialUpload = ({
onSuccess?: () => void;
}) => {
const { mutate } = useSWRConfig();
const [isUploading, setIsUploading] = useState(false);
const [fileName, setFileName] = useState<string | undefined>();
const [isDragging, setIsDragging] = useState(false);
const handleFileUpload = async (file: File) => {
setIsUploading(true);
setFileName(file.name);
const reader = new FileReader();
reader.onload = async (loadEvent) => {
if (!loadEvent?.target?.result) {
setIsUploading(false);
return;
}
const credentialJsonStr = loadEvent.target.result as string;
// Check credential type
let credentialFileType: GmailCredentialJsonTypes;
try {
const appCredentialJson = JSON.parse(credentialJsonStr);
if (appCredentialJson.web) {
credentialFileType = "authorized_user";
} else if (appCredentialJson.type === "service_account") {
credentialFileType = "service_account";
} else {
throw new Error(
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
);
}
} catch (e) {
setPopup({
message: `Invalid file provided - ${e}`,
type: "error",
});
setIsUploading(false);
return;
}
if (credentialFileType === "authorized_user") {
const response = await fetch(
"/api/manage/admin/connector/gmail/app-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded app credentials",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/app-credential");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload app credentials - ${errorMsg}`,
type: "error",
});
}
}
if (credentialFileType === "service_account") {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-key",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded service account key",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/service-account-key");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload service account key - ${errorMsg}`,
type: "error",
});
}
}
setIsUploading(false);
};
reader.readAsText(file);
};
const handleDragEnter = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
if (!isUploading) {
setIsDragging(true);
}
};
const handleDragLeave = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
setIsDragging(false);
};
const handleDragOver = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
};
const handleDrop = (e: React.DragEvent<HTMLLabelElement>) => {
e.preventDefault();
e.stopPropagation();
setIsDragging(false);
if (isUploading) return;
const files = e.dataTransfer.files;
if (files.length > 0) {
const file = files[0];
if (file.type === "application/json" || file.name.endsWith(".json")) {
handleFileUpload(file);
} else {
setPopup({
message: "Please upload a JSON file",
type: "error",
});
}
}
};
const [credentialJsonStr, setCredentialJsonStr] = useState<
string | undefined
>();
return (
<div className="flex flex-col mt-4">
<div className="flex items-center">
<div className="relative flex flex-1 items-center">
<label
className={cn(
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
isUploading
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
: isDragging
? "bg-background-50/50 border-primary dark:border-primary"
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
)}
onDragEnter={handleDragEnter}
onDragLeave={handleDragLeave}
onDragOver={handleDragOver}
onDrop={handleDrop}
>
<div className="flex items-center space-x-2">
{isUploading ? (
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
) : (
<FiFile className="h-4 w-4 text-text-500" />
)}
<span className="text-sm text-text-500">
{isUploading
? `Uploading ${truncateString(fileName || "file", 50)}...`
: isDragging
? "Drop JSON file here"
: truncateString(
fileName || "Select or drag JSON credentials file...",
50
)}
</span>
</div>
<input
className="sr-only"
type="file"
accept=".json"
disabled={isUploading}
onChange={(event) => {
if (!event.target.files?.length) {
return;
}
const file = event.target.files[0];
handleFileUpload(file);
}}
/>
</label>
</div>
</div>
</div>
<>
<input
className={
"mr-3 text-sm text-text-900 border border-background-300 overflow-visible " +
"cursor-pointer bg-background dark:text-text-400 focus:outline-none " +
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
}
type="file"
accept=".json"
onChange={(event) => {
if (!event.target.files) {
return;
}
const file = event.target.files[0];
const reader = new FileReader();
reader.onload = function (loadEvent) {
if (!loadEvent?.target?.result) {
return;
}
const fileContents = loadEvent.target.result;
setCredentialJsonStr(fileContents as string);
};
reader.readAsText(file);
}}
/>
<Button
disabled={!credentialJsonStr}
onClick={async () => {
// check if the JSON is a app credential or a service account credential
let credentialFileType: GmailCredentialJsonTypes;
try {
const appCredentialJson = JSON.parse(credentialJsonStr!);
if (appCredentialJson.web) {
credentialFileType = "authorized_user";
} else if (appCredentialJson.type === "service_account") {
credentialFileType = "service_account";
} else {
throw new Error(
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
);
}
} catch (e) {
setPopup({
message: `Invalid file provided - ${e}`,
type: "error",
});
return;
}
if (credentialFileType === "authorized_user") {
const response = await fetch(
"/api/manage/admin/connector/gmail/app-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded app credentials",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/app-credential");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload app credentials - ${errorMsg}`,
type: "error",
});
}
}
if (credentialFileType === "service_account") {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-key",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: credentialJsonStr,
}
);
if (response.ok) {
setPopup({
message: "Successfully uploaded service account key",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/service-account-key");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload service account key - ${errorMsg}`,
type: "error",
});
}
}
}}
>
Upload
</Button>
</>
);
};
interface GmailJsonUploadSectionProps {
interface DriveJsonUploadSectionProps {
setPopup: (popupSpec: PopupSpec | null) => void;
appCredentialData?: { client_id: string };
serviceAccountCredentialData?: { service_account_email: string };
isAdmin: boolean;
onSuccess?: () => void;
existingAuthCredential?: boolean;
}
export const GmailJsonUploadSection = ({
@@ -256,8 +167,7 @@ export const GmailJsonUploadSection = ({
serviceAccountCredentialData,
isAdmin,
onSuccess,
existingAuthCredential,
}: GmailJsonUploadSectionProps) => {
}: DriveJsonUploadSectionProps) => {
const { mutate } = useSWRConfig();
const router = useRouter();
const [localServiceAccountData, setLocalServiceAccountData] = useState(
@@ -280,138 +190,156 @@ export const GmailJsonUploadSection = ({
}
};
if (!isAdmin) {
if (localServiceAccountData?.service_account_email) {
return (
<div>
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<p className="text-sm">
Curators are unable to set up the Gmail credentials. To add a Gmail
connector, please contact an administrator.
<div className="mt-2 text-sm">
<div>
Found existing service account key with the following <b>Email:</b>
<p className="italic mt-1">
{localServiceAccountData.service_account_email}
</p>
</div>
{isAdmin ? (
<>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-key",
{
method: "DELETE",
}
);
if (response.ok) {
mutate(
"/api/manage/admin/connector/gmail/service-account-key"
);
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
setPopup({
message: "Successfully deleted service account key",
type: "success",
});
// Immediately update local state
setLocalServiceAccountData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete service account key - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
</>
) : (
<>
<div className="mt-4 mb-1">
To change these credentials, please contact an administrator.
</div>
</>
)}
</div>
);
}
if (localAppCredentialData?.client_id) {
return (
<div className="mt-2 text-sm">
<div>
Found existing app credentials with the following <b>Client ID:</b>
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
</div>
{isAdmin ? (
<>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/gmail/app-credential",
{
method: "DELETE",
}
);
if (response.ok) {
mutate("/api/manage/admin/connector/gmail/app-credential");
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
setPopup({
message: "Successfully deleted app credentials",
type: "success",
});
// Immediately update local state
setLocalAppCredentialData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete app credential - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
</>
) : (
<div className="mt-4 mb-1">
To change these credentials, please contact an administrator.
</div>
)}
</div>
);
}
if (!isAdmin) {
return (
<div className="mt-2">
<p className="text-sm mb-2">
Curators are unable to set up the Gmail credentials. To add a Gmail
connector, please contact an administrator.
</p>
</div>
);
}
return (
<div>
<p className="text-sm mb-3">
To connect your Gmail, create credentials (either OAuth App or Service
Account), download the JSON file, and upload it below.
</p>
<div className="mb-4">
<div className="mt-2">
<p className="text-sm mb-2">
Follow the guide{" "}
<a
className="text-primary hover:text-primary/80 flex items-center gap-1 text-sm"
className="text-link"
target="_blank"
href="https://docs.onyx.app/connectors/gmail#authorization"
rel="noreferrer"
>
<FiLink className="h-3 w-3" />
View detailed setup instructions
</a>
</div>
{(localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id) && (
<div className="mb-4">
<div className="relative flex flex-1 items-center">
<label
className={cn(
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
false
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
)}
>
<div className="flex items-center space-x-2">
{false ? (
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
) : (
<FiFile className="h-4 w-4 text-text-500" />
)}
<span className="text-sm text-text-500">
{truncateString(
localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id ||
"",
50
)}
</span>
</div>
</label>
</div>
{isAdmin && !existingAuthCredential && (
<div className="mt-2">
<Button
variant="destructive"
type="button"
onClick={async () => {
const endpoint =
localServiceAccountData?.service_account_email
? "/api/manage/admin/connector/gmail/service-account-key"
: "/api/manage/admin/connector/gmail/app-credential";
const response = await fetch(endpoint, {
method: "DELETE",
});
if (response.ok) {
mutate(endpoint);
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
// Add additional mutations to refresh all credential-related endpoints
mutate("/api/manage/admin/connector/gmail/credentials");
mutate(
"/api/manage/admin/connector/gmail/public-credential"
);
mutate(
"/api/manage/admin/connector/gmail/service-account-credential"
);
setPopup({
message: `Successfully deleted ${
localServiceAccountData
? "service account key"
: "app credentials"
}`,
type: "success",
});
// Immediately update local state
if (localServiceAccountData) {
setLocalServiceAccountData(undefined);
} else {
setLocalAppCredentialData(undefined);
}
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete credentials - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete Credentials
</Button>
</div>
)}
</div>
)}
{!(
localServiceAccountData?.service_account_email ||
localAppCredentialData?.client_id
) && (
<GmailCredentialUpload setPopup={setPopup} onSuccess={handleSuccess} />
)}
here
</a>{" "}
to either (1) setup a Google OAuth App in your company workspace or (2)
create a Service Account.
<br />
<br />
Download the credentials JSON if choosing option (1) or the Service
Account key JSON if choosing option (2), and upload it here.
</p>
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
</div>
);
};
interface GmailCredentialSectionProps {
interface DriveCredentialSectionProps {
gmailPublicCredential?: Credential<GmailCredentialJson>;
gmailServiceAccountCredential?: Credential<GmailServiceAccountCredentialJson>;
serviceAccountKeyData?: { service_account_email: string };
@@ -459,7 +387,7 @@ export const GmailAuthSection = ({
refreshCredentials,
connectorExists,
user,
}: GmailCredentialSectionProps) => {
}: DriveCredentialSectionProps) => {
const router = useRouter();
const [isAuthenticating, setIsAuthenticating] = useState(false);
const [localServiceAccountData, setLocalServiceAccountData] = useState(
@@ -492,141 +420,104 @@ export const GmailAuthSection = ({
localGmailPublicCredential || localGmailServiceAccountCredential;
if (existingCredential) {
return (
<div>
<div className="mt-4">
<div className="py-3 px-4 bg-blue-50/30 dark:bg-blue-900/5 rounded mb-4 flex items-start">
<FiCheck className="text-blue-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<div className="flex-1">
<span className="font-medium block">Authentication Complete</span>
<p className="text-sm mt-1 text-text-500 dark:text-text-400 break-words">
Your Gmail credentials have been successfully uploaded and
authenticated.
</p>
</div>
</div>
<Button
variant="destructive"
type="button"
onClick={async () => {
handleRevokeAccess(
connectorExists,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
</Button>
</div>
</div>
);
}
// If no credentials are uploaded, show message to complete step 1 first
if (
!localServiceAccountData?.service_account_email &&
!localAppCredentialData?.client_id
) {
return (
<div>
<SectionHeader>Gmail Authentication</SectionHeader>
<div className="mt-4">
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
<p className="text-sm">
Please complete Step 1 by uploading either OAuth credentials or a
Service Account key before proceeding with authentication.
</p>
</div>
</div>
</div>
<>
<p className="mb-2 text-sm">
<i>Uploaded and authenticated credential already exists!</i>
</p>
<Button
onClick={async () => {
handleRevokeAccess(
connectorExists,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
</Button>
</>
);
}
if (localServiceAccountData?.service_account_email) {
return (
<div>
<div className="mt-4">
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string()
.email("Must be a valid email")
.required("Required"),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
);
if (response.ok) {
setPopup({
message: "Successfully created service account credential",
type: "success",
});
refreshCredentials();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string()
.email("Must be a valid email")
.required("Required"),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-credential",
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
google_primary_admin: values.google_primary_admin,
}),
}
} catch (error) {
);
if (response.ok) {
setPopup({
message: `Failed to create service account credential - ${error}`,
message: "Successfully created service account credential",
type: "success",
});
refreshCredentials();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to create service account credential - ${errorMsg}`,
type: "error",
});
} finally {
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
/>
<div className="flex">
<Button type="submit" disabled={isSubmitting}>
{isSubmitting ? "Creating..." : "Create Credential"}
</Button>
</div>
</Form>
)}
</Formik>
</div>
} catch (error) {
setPopup({
message: `Failed to create service account credential - ${error}`,
type: "error",
});
} finally {
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
/>
<div className="flex">
<Button type="submit" disabled={isSubmitting}>
Create Credential
</Button>
</div>
</Form>
)}
</Formik>
</div>
);
}
if (localAppCredentialData?.client_id) {
return (
<div>
<div className="bg-background-50/30 dark:bg-background-900/20 rounded mb-4">
<p className="text-sm">
Next, you need to authenticate with Gmail via OAuth. This gives us
read access to the emails you have access to in your Gmail account.
</p>
</div>
<div className="text-sm mb-4">
<p className="mb-2">
Next, you must provide credentials via OAuth. This gives us read
access to the emails you have access to in your Gmail account.
</p>
<Button
disabled={isAuthenticating}
onClick={async () => {
setIsAuthenticating(true);
try {
@@ -654,6 +545,7 @@ export const GmailAuthSection = ({
setIsAuthenticating(false);
}
}}
disabled={isAuthenticating}
>
{isAuthenticating ? "Authenticating..." : "Authenticate with Gmail"}
</Button>
@@ -661,6 +553,11 @@ export const GmailAuthSection = ({
);
}
// This code path should not be reached with the new conditions above
return null;
// case where no keys have been uploaded in step 1
return (
<p className="text-sm">
Please upload either a OAuth Client Credential JSON or a Gmail Service
Account Key JSON in Step 1 before moving onto Step 2.
</p>
);
};

View File

@@ -173,9 +173,6 @@ export const GmailMain = () => {
serviceAccountCredentialData={serviceAccountKeyData}
isAdmin={isAdmin}
onSuccess={handleRefresh}
existingAuthCredential={Boolean(
gmailPublicUploadedCredential || gmailServiceAccountCredential
)}
/>
{isAdmin && hasUploadedCredentials && (

View File

@@ -114,8 +114,8 @@ function Main() {
<ul className="list-disc mt-2 ml-4 mb-2">
<li>
<Text>
Set a global rate limit to control your team&apos;s overall token
spend.
Set a global rate limit to control your organization&apos;s overall
token spend.
</Text>
</li>
{isPaidEnterpriseFeaturesEnabled && (

View File

@@ -2,7 +2,14 @@
import { useState, useEffect, useCallback } from "react";
import { useRouter } from "next/navigation";
import { Formik, Form, Field, ErrorMessage, FieldArray } from "formik";
import {
Formik,
Form,
Field,
ErrorMessage,
FieldArray,
ArrayHelpers,
} from "formik";
import * as Yup from "yup";
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
import { TextFormField } from "@/components/admin/connectors/Field";
@@ -42,7 +49,7 @@ function prettifyDefinition(definition: any) {
return JSON.stringify(definition, null, 2);
}
function ActionForm({
function ToolForm({
existingTool,
values,
setFieldValue,
@@ -111,7 +118,7 @@ function ActionForm({
<TextFormField
name="definition"
label="Definition"
subtext="Specify an OpenAPI schema that defines the APIs you want to make available as part of this action."
subtext="Specify an OpenAPI schema that defines the APIs you want to make available as part of this tool."
placeholder="Enter your OpenAPI schema here"
isTextArea={true}
defaultHeight="h-96"
@@ -178,7 +185,7 @@ function ActionForm({
clipRule="evenodd"
/>
</svg>
Learn more about actions in our documentation
Learn more about tool calling in our documentation
</Link>
</div>
@@ -222,7 +229,7 @@ function ActionForm({
Custom Headers
</h3>
<p className="text-sm mb-6 text-text-600 italic">
Specify custom headers for each request to this action&apos;s API.
Specify custom headers for each request to this tool&apos;s API.
</p>
<FieldArray
name="customHeaders"
@@ -353,7 +360,7 @@ function ActionForm({
type="submit"
disabled={isSubmitting || !!definitionError}
>
{existingTool ? "Update Action" : "Create Action"}
{existingTool ? "Update Tool" : "Create Tool"}
</Button>
</div>
</Form>
@@ -379,7 +386,7 @@ const ToolSchema = Yup.object().shape({
passthrough_auth: Yup.boolean().default(false),
});
export function ActionEditor({ tool }: { tool?: ToolSnapshot }) {
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
const router = useRouter();
const { popup, setPopup } = usePopup();
const [definitionError, setDefinitionError] = useState<string | null>(null);
@@ -425,7 +432,7 @@ export function ActionEditor({ tool }: { tool?: ToolSnapshot }) {
try {
definition = parseJsonWithTrailingCommas(values.definition);
} catch (error) {
setDefinitionError("Invalid JSON in action definition");
setDefinitionError("Invalid JSON in tool definition");
return;
}
@@ -446,17 +453,17 @@ export function ActionEditor({ tool }: { tool?: ToolSnapshot }) {
}
if (response.error) {
setPopup({
message: "Failed to create action - " + response.error,
message: "Failed to create tool - " + response.error,
type: "error",
});
return;
}
router.push(`/admin/actions?u=${Date.now()}`);
router.push(`/admin/tools?u=${Date.now()}`);
}}
>
{({ isSubmitting, values, setFieldValue }) => {
return (
<ActionForm
<ToolForm
existingTool={tool}
values={values}
setFieldValue={setFieldValue}

View File

@@ -15,7 +15,7 @@ import { TrashIcon } from "@/components/icons/icons";
import { deleteCustomTool } from "@/lib/tools/edit";
import { TableHeader } from "@/components/ui/table";
export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
export function ToolsTable({ tools }: { tools: ToolSnapshot[] }) {
const router = useRouter();
const { popup, setPopup } = usePopup();

View File

@@ -2,7 +2,7 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
import CardSection from "@/components/admin/CardSection";
import { ActionEditor } from "@/app/admin/actions/ActionEditor";
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
import { fetchToolByIdSS } from "@/lib/tools/fetchTools";
import { DeleteToolButton } from "./DeleteToolButton";
import { AdminPageTitle } from "@/components/admin/Title";
@@ -31,7 +31,7 @@ export default async function Page(props: {
<div>
<div>
<CardSection>
<ActionEditor tool={tool} />
<ToolEditor tool={tool} />
</CardSection>
<Title className="mt-12">Delete Tool</Title>

View File

@@ -1,6 +1,6 @@
"use client";
import { ActionEditor } from "@/app/admin/actions/ActionEditor";
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
import { BackButton } from "@/components/BackButton";
import { AdminPageTitle } from "@/components/admin/Title";
import { ToolIcon } from "@/components/icons/icons";
@@ -12,12 +12,12 @@ export default function NewToolPage() {
<BackButton />
<AdminPageTitle
title="Create Action"
title="Create Tool"
icon={<ToolIcon size={32} className="my-auto" />}
/>
<CardSection>
<ActionEditor />
<ToolEditor />
</CardSection>
</div>
);

View File

@@ -1,4 +1,4 @@
import { ActionsTable } from "./ActionTable";
import { ToolsTable } from "./ToolsTable";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { FiPlusSquare } from "react-icons/fi";
import Link from "next/link";
@@ -29,23 +29,23 @@ export default async function Page() {
<div className="mx-auto container">
<AdminPageTitle
icon={<ToolIcon size={32} className="my-auto" />}
title="Actions"
title="Tools"
/>
<Text className="mb-2">
Actions allow assistants to retrieve information or take actions.
Tools allow assistants to retrieve information or take actions.
</Text>
<div>
<Separator />
<Title>Create an Action</Title>
<CreateButton href="/admin/actions/new" text="New Action" />
<Title>Create a Tool</Title>
<CreateButton href="/admin/tools/new" text="New Tool" />
<Separator />
<Title>Existing Actions</Title>
<ActionsTable tools={tools} />
<Title>Existing Tools</Title>
<ToolsTable tools={tools} />
</div>
</div>
);

View File

@@ -21,8 +21,7 @@ import { InvitedUserSnapshot } from "@/lib/types";
import { SearchBar } from "@/components/search/SearchBar";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import PendingUsersTable from "@/components/admin/users/PendingUsersTable";
import { useUser } from "@/components/user/UserProvider";
const UsersTables = ({
q,
setPopup,
@@ -45,15 +44,6 @@ const UsersTables = ({
errorHandlingFetcher
);
const {
data: pendingUsers,
error: pendingUsersError,
isLoading: pendingUsersLoading,
mutate: pendingUsersMutate,
} = useSWR<InvitedUserSnapshot[]>(
NEXT_PUBLIC_CLOUD_ENABLED ? "/api/tenants/users/pending" : null,
errorHandlingFetcher
);
// Show loading animation only during the initial data fetch
if (!validDomains) {
return <ThreeDotsLoader />;
@@ -73,9 +63,6 @@ const UsersTables = ({
<TabsList>
<TabsTrigger value="current">Current Users</TabsTrigger>
<TabsTrigger value="invited">Invited Users</TabsTrigger>
{NEXT_PUBLIC_CLOUD_ENABLED && (
<TabsTrigger value="pending">Pending Users</TabsTrigger>
)}
</TabsList>
<TabsContent value="current">
@@ -110,25 +97,6 @@ const UsersTables = ({
</CardContent>
</Card>
</TabsContent>
{NEXT_PUBLIC_CLOUD_ENABLED && (
<TabsContent value="pending">
<Card>
<CardHeader>
<CardTitle>Pending Users</CardTitle>
</CardHeader>
<CardContent>
<PendingUsersTable
users={pendingUsers || []}
setPopup={setPopup}
mutate={pendingUsersMutate}
error={pendingUsersError}
isLoading={pendingUsersLoading}
q={q}
/>
</CardContent>
</Card>
</TabsContent>
)}
</Tabs>
);
};
@@ -222,7 +190,7 @@ const AddUserButton = ({
entityName="your Access Logic"
onClose={() => setShowConfirmation(false)}
onSubmit={handleConfirmFirstInvite}
additionalDetails="After inviting the first user, only invited users will be able to join this platform. This is a security measure to control access to your team."
additionalDetails="After inviting the first user, only invited users will be able to join this platform. This is a security measure to control access to your instance."
actionButtonText="Continue"
variant="action"
/>

View File

@@ -18,8 +18,8 @@ const Page = () => {
need to either:
</p>
<ul className="list-disc text-left text-text-600 w-full pl-6 mx-auto">
<li>Be invited to an existing Onyx team</li>
<li>Create a new Onyx team</li>
<li>Be invited to an existing Onyx organization</li>
<li>Create a new Onyx organization</li>
</ul>
<div className="flex justify-center">
<Link

View File

@@ -1,108 +0,0 @@
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { User } from "@/lib/types";
import {
getCurrentUserSS,
getAuthTypeMetadataSS,
AuthTypeMetadata,
getAuthUrlSS,
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm";
import Text from "@/components/ui/text";
import Link from "next/link";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import AuthErrorDisplay from "@/components/auth/AuthErrorDisplay";
const Page = async (props: {
searchParams?: Promise<{ [key: string]: string | string[] | undefined }>;
}) => {
const searchParams = await props.searchParams;
const nextUrl = Array.isArray(searchParams?.next)
? searchParams?.next[0]
: searchParams?.next || null;
const defaultEmail = Array.isArray(searchParams?.email)
? searchParams?.email[0]
: searchParams?.email || null;
const teamName = Array.isArray(searchParams?.team)
? searchParams?.team[0]
: searchParams?.team || "your team";
// catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page
// will not render
let authTypeMetadata: AuthTypeMetadata | null = null;
let currentUser: User | null = null;
try {
[authTypeMetadata, currentUser] = await Promise.all([
getAuthTypeMetadataSS(),
getCurrentUserSS(),
]);
} catch (e) {
console.log(`Some fetch failed for the login page - ${e}`);
}
// simply take the user to the home page if Auth is disabled
if (authTypeMetadata?.authType === "disabled") {
return redirect("/chat");
}
// if user is already logged in, take them to the main app page
if (currentUser && currentUser.is_active && !currentUser.is_anonymous_user) {
if (!authTypeMetadata?.requiresVerification || currentUser.is_verified) {
return redirect("/chat");
}
return redirect("/auth/waiting-on-verification");
}
const cloud = authTypeMetadata?.authType === "cloud";
// only enable this page if basic login is enabled
if (authTypeMetadata?.authType !== "basic" && !cloud) {
return redirect("/chat");
}
let authUrl: string | null = null;
if (cloud && authTypeMetadata) {
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
}
const emailDomain = defaultEmail?.split("@")[1];
return (
<AuthFlowContainer authState="join">
<HealthCheckBanner />
<AuthErrorDisplay searchParams={searchParams} />
<>
<div className="absolute top-10x w-full"></div>
<div className="flex w-full flex-col justify-center">
<h2 className="text-center text-xl text-strong font-bold">
Re-authenticate to join team
</h2>
{cloud && authUrl && (
<div className="w-full justify-center">
<SignInButton authorizeUrl={authUrl} authType="cloud" />
<div className="flex items-center w-full my-4">
<div className="flex-grow border-t border-background-300"></div>
<span className="px-4 text-text-500">or</span>
<div className="flex-grow border-t border-background-300"></div>
</div>
</div>
)}
<EmailPasswordForm
isSignup
isJoin
shouldVerify={authTypeMetadata?.requiresVerification}
nextUrl={nextUrl}
defaultEmail={defaultEmail}
/>
</div>
</>
</AuthFlowContainer>
);
};
export default Page;

View File

@@ -13,7 +13,6 @@ import { set } from "lodash";
import { NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED } from "@/lib/constants";
import Link from "next/link";
import { useUser } from "@/components/user/UserProvider";
import { useRouter } from "next/navigation";
export function EmailPasswordForm({
isSignup = false,
@@ -21,18 +20,15 @@ export function EmailPasswordForm({
referralSource,
nextUrl,
defaultEmail,
isJoin = false,
}: {
isSignup?: boolean;
shouldVerify?: boolean;
referralSource?: string;
nextUrl?: string | null;
defaultEmail?: string | null;
isJoin?: boolean;
}) {
const { user } = useUser();
const { popup, setPopup } = usePopup();
const router = useRouter();
const [isWorking, setIsWorking] = useState(false);
return (
<>
@@ -83,11 +79,6 @@ export function EmailPasswordForm({
});
setIsWorking(false);
return;
} else {
setPopup({
type: "success",
message: "Account created successfully. Please log in.",
});
}
}
@@ -101,9 +92,7 @@ export function EmailPasswordForm({
window.location.href = "/auth/waiting-on-verification";
} else {
// See above comment
window.location.href = nextUrl
? encodeURI(nextUrl)
: `/chat${isSignup && !isJoin ? "?new_team=true" : ""}`;
window.location.href = nextUrl ? encodeURI(nextUrl) : "/";
}
} else {
setIsWorking(false);
@@ -146,12 +135,11 @@ export function EmailPasswordForm({
/>
<Button
variant="agent"
type="submit"
disabled={isSubmitting}
className="mx-auto !py-4 w-full"
>
{isJoin ? "Join" : isSignup ? "Sign Up" : "Log In"}
{isSignup ? "Sign Up" : "Log In"}
</Button>
{user?.is_anonymous_user && (
<Link

View File

@@ -51,16 +51,25 @@ export default function LoginPage({
</div>
<EmailPasswordForm shouldVerify={true} nextUrl={nextUrl} />
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
<div className="flex mt-4 justify-between">
<div className="flex mt-4 justify-between">
<Link
href={`/auth/signup${
searchParams?.next ? `?next=${searchParams.next}` : ""
}`}
className="text-link font-medium"
>
Create an account
</Link>
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
<Link
href="/auth/forgot-password"
className="text-link font-medium"
>
Reset Password
</Link>
</div>
)}
)}
</div>
</div>
)}

View File

@@ -46,7 +46,7 @@ export function SignInButton({
return (
<a
className="mx-auto mb-4 mt-6 py-3 w-full dark:text-neutral-300 text-neutral-600 border border-neutral-300 flex rounded cursor-pointer hover:border-neutral-400 transition-colors"
className="mx-auto mb-4 mt-6 py-3 w-full text-neutral-100 bg-indigo-500 flex rounded cursor-pointer hover:bg-indigo-800"
href={finalAuthorizeUrl}
>
{button}

View File

@@ -3,7 +3,7 @@
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { useContext, useState, useRef, useLayoutEffect } from "react";
import { ChevronDownIcon } from "@/components/icons/icons";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import { MinimalMarkdown } from "@/components/chat/MinimalMarkdown";
export function ChatBanner() {
const settings = useContext(SettingsContext);

View File

@@ -109,6 +109,7 @@ import {
} from "@/components/resizable/constants";
import FixedLogo from "../../components/logo/FixedLogo";
import { MinimalMarkdown } from "@/components/chat/MinimalMarkdown";
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import {
@@ -137,7 +138,6 @@ import { useSidebarShortcut } from "@/lib/browserUtilities";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
import { ErrorBanner } from "./message/Resubmit";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -215,7 +215,11 @@ export function ChatPage({
const isInitialLoad = useRef(true);
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
const { assistants: availableAssistants, pinnedAssistants } = useAssistants();
const {
assistants: availableAssistants,
finalAssistants,
pinnedAssistants,
} = useAssistants();
const [showApiKeyModal, setShowApiKeyModal] = useState(
!shouldShowWelcomeModal
@@ -225,7 +229,7 @@ export function ChatPage({
const slackChatId = searchParams.get("slackChatId");
const existingChatIdRaw = searchParams.get("chatId");
const [showHistorySidebar, setShowHistorySidebar] = useState(false);
const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open
const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
@@ -2447,7 +2451,7 @@ export function ChatPage({
h-full
${sidebarVisible ? "w-[200px]" : "w-[0px]"}
`}
/>
></div>
)}
</div>
)}

View File

@@ -6,7 +6,6 @@ import { Button } from "@/components/ui/button";
import { useContext, useEffect, useState } from "react";
import ReactMarkdown from "react-markdown";
import remarkGfm from "remark-gfm";
import { transformLinkUri } from "@/lib/utils";
const ALL_USERS_INITIAL_POPUP_FLOW_COMPLETED =
"allUsersInitialPopupFlowCompleted";
@@ -45,26 +44,23 @@ export function ChatPopup() {
return (
<Modal width="w-3/6 xl:w-[700px]" title={popupTitle}>
<>
<div className="overflow-y-auto max-h-[90vh] py-8 px-4 text-left">
<ReactMarkdown
className="prose text-text-800 dark:text-neutral-100 max-w-full"
components={{
a: ({ node, ...props }) => (
<a
{...props}
className="text-link hover:text-link-hover"
target="_blank"
rel="noopener noreferrer"
/>
),
p: ({ node, ...props }) => <p {...props} className="text-sm" />,
}}
remarkPlugins={[remarkGfm]}
urlTransform={transformLinkUri}
>
{popupContent}
</ReactMarkdown>
</div>
<ReactMarkdown
className="prose text-text-800 dark:text-neutral-100 max-w-full"
components={{
a: ({ node, ...props }) => (
<a
{...props}
className="text-link hover:text-link-hover"
target="_blank"
rel="noopener noreferrer"
/>
),
p: ({ node, ...props }) => <p {...props} className="text-sm" />,
}}
remarkPlugins={[remarkGfm]}
>
{popupContent}
</ReactMarkdown>
{showConsentError && (
<p className="text-red-500 text-sm mt-2">

View File

@@ -53,7 +53,6 @@ import { copyAll, handleCopy } from "./copyingUtils";
import { Button } from "@/components/ui/button";
import { RefreshCw } from "lucide-react";
import { ErrorBanner, Resubmit } from "./Resubmit";
import { transformLinkUri } from "@/lib/utils";
export const AgenticMessage = ({
isStreamingQuestions,
@@ -337,7 +336,6 @@ export const AgenticMessage = ({
}}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
urlTransform={transformLinkUri}
>
{finalAlternativeContent}
</ReactMarkdown>
@@ -351,7 +349,6 @@ export const AgenticMessage = ({
components={markdownComponents}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
urlTransform={transformLinkUri}
>
{streamedContent +
(!isComplete && !secondLevelGenerating ? " [*]() " : "")}

View File

@@ -160,9 +160,8 @@ export const MemoizedLink = memo(
const handleMouseDown = () => {
let url = href || rest.children?.toString();
if (url && !url.includes("://")) {
// Only add https:// if the URL doesn't already have a protocol
if (url && !url.startsWith("http://") && !url.startsWith("https://")) {
// Try to construct a valid URL
const httpsUrl = `https://${url}`;
try {
new URL(httpsUrl);

View File

@@ -71,7 +71,6 @@ import remarkMath from "remark-math";
import rehypeKatex from "rehype-katex";
import "katex/dist/katex.min.css";
import { copyAll, handleCopy } from "./copyingUtils";
import { transformLinkUri } from "@/lib/utils";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -349,7 +348,7 @@ export const AIMessage = ({
a: anchorCallback,
p: paragraphCallback,
b: ({ node, className, children }: any) => {
return <span className={className}>{children}</span>;
return <span className={className}>||||{children}</span>;
},
code: ({ node, className, children }: any) => {
const codeText = extractCodeText(
@@ -382,7 +381,6 @@ export const AIMessage = ({
components={markdownComponents}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
urlTransform={transformLinkUri}
>
{finalContent}
</ReactMarkdown>

View File

@@ -16,15 +16,15 @@ import ReactMarkdown from "react-markdown";
import { MemoizedAnchor } from "./MemoizedTextComponents";
import { MemoizedParagraph } from "./MemoizedTextComponents";
import { extractCodeText, preprocessLaTeX } from "./codeUtils";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import rehypeKatex from "rehype-katex";
import remarkGfm from "remark-gfm";
import { CodeBlock } from "./CodeBlock";
import { CheckIcon, ChevronDown } from "lucide-react";
import { PHASE_MIN_MS, useStreamingMessages } from "./StreamingMessages";
import { CirclingArrowIcon } from "@/components/icons/icons";
import { handleCopy } from "./copyingUtils";
import { transformLinkUri } from "@/lib/utils";
export const StatusIndicator = ({ status }: { status: ToggleState }) => {
return (
@@ -301,7 +301,6 @@ const SubQuestionDisplay: React.FC<{
components={markdownComponents}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[rehypeKatex]}
urlTransform={transformLinkUri}
>
{finalContent}
</ReactMarkdown>

View File

@@ -62,13 +62,19 @@ export function extractCodeText(
// We must preprocess LaTeX in the LLM output to avoid improper formatting
export const preprocessLaTeX = (content: string) => {
// 1) Replace block-level LaTeX delimiters \[ \] with $$ $$
const blockProcessedContent = content.replace(
// 1) Escape dollar signs used outside of LaTeX context
const escapedCurrencyContent = content.replace(
/\$(\d+(?:\.\d*)?)/g,
(_, p1) => `\\$${p1}`
);
// 2) Replace block-level LaTeX delimiters \[ \] with $$ $$
const blockProcessedContent = escapedCurrencyContent.replace(
/\\\[([\s\S]*?)\\\]/g,
(_, equation) => `$$${equation}$$`
);
// 2) Replace inline LaTeX delimiters \( \) with $ $
// 3) Replace inline LaTeX delimiters \( \) with $ $
const inlineProcessedContent = blockProcessedContent.replace(
/\\\(([\s\S]*?)\\\)/g,
(_, equation) => `$${equation}$`
@@ -76,3 +82,223 @@ export const preprocessLaTeX = (content: string) => {
return inlineProcessedContent;
};
interface MarkdownSegment {
type: "text" | "link" | "code" | "bold" | "italic" | "codeblock";
text: string; // The visible/plain text
raw: string; // The raw markdown including syntax
length: number; // Length of the visible text
}
export function parseMarkdownToSegments(markdown: string): MarkdownSegment[] {
if (!markdown) {
return [];
}
const segments: MarkdownSegment[] = [];
let currentIndex = 0;
const maxIterations = markdown.length * 2; // Prevent infinite loops
let iterations = 0;
while (currentIndex < markdown.length && iterations < maxIterations) {
iterations++;
let matched = false;
// Check for code blocks first (they take precedence)
const codeBlockMatch = markdown
.slice(currentIndex)
.match(/^```(\w*)\n([\s\S]*?)```/);
if (codeBlockMatch && codeBlockMatch[0]) {
const [fullMatch, , code] = codeBlockMatch;
segments.push({
type: "codeblock",
text: code || "",
raw: fullMatch,
length: (code || "").length,
});
currentIndex += fullMatch.length;
matched = true;
continue;
}
// Check for inline code
const inlineCodeMatch = markdown.slice(currentIndex).match(/^`([^`]+)`/);
if (inlineCodeMatch && inlineCodeMatch[0]) {
const [fullMatch, code] = inlineCodeMatch;
segments.push({
type: "code",
text: code || "",
raw: fullMatch,
length: (code || "").length,
});
currentIndex += fullMatch.length;
matched = true;
continue;
}
// Check for links
const linkMatch = markdown
.slice(currentIndex)
.match(/^\[([^\]]+)\]\(([^)]+)\)/);
if (linkMatch && linkMatch[0]) {
const [fullMatch, text] = linkMatch;
segments.push({
type: "link",
text: text || "",
raw: fullMatch,
length: (text || "").length,
});
currentIndex += fullMatch.length;
matched = true;
continue;
}
// Check for bold
const boldMatch = markdown
.slice(currentIndex)
.match(/^(\*\*|__)([^*_\n]*?)\1/);
if (boldMatch && boldMatch[0]) {
const [fullMatch, , text] = boldMatch;
segments.push({
type: "bold",
text: text || "",
raw: fullMatch,
length: (text || "").length,
});
currentIndex += fullMatch.length;
matched = true;
continue;
}
// Check for italic
const italicMatch = markdown
.slice(currentIndex)
.match(/^(\*|_)([^*_\n]+?)\1(?!\*|_)/);
if (italicMatch && italicMatch[0]) {
const [fullMatch, , text] = italicMatch;
segments.push({
type: "italic",
text: text || "",
raw: fullMatch,
length: (text || "").length,
});
currentIndex += fullMatch.length;
matched = true;
continue;
}
// If no matches were found, handle regular text
if (!matched) {
let nextSpecialChar = markdown.slice(currentIndex).search(/[`\[*_]/);
if (nextSpecialChar === -1) {
// No more special characters, add the rest as text
const text = markdown.slice(currentIndex);
if (text) {
segments.push({
type: "text",
text: text,
raw: text,
length: text.length,
});
}
break;
} else {
// Add the text up to the next special character
const text = markdown.slice(
currentIndex,
currentIndex + nextSpecialChar
);
if (text) {
segments.push({
type: "text",
text: text,
raw: text,
length: text.length,
});
}
currentIndex += nextSpecialChar;
}
}
}
return segments;
}
export function getMarkdownForSelection(
content: string,
selectedText: string
): string {
const segments = parseMarkdownToSegments(content);
// Build plain text and create mapping to markdown segments
let plainText = "";
const markdownPieces: string[] = [];
let currentPlainIndex = 0;
segments.forEach((segment) => {
plainText += segment.text;
markdownPieces.push(segment.raw);
currentPlainIndex += segment.length;
});
// Find the selection in the plain text
const startIndex = plainText.indexOf(selectedText);
if (startIndex === -1) {
return selectedText;
}
const endIndex = startIndex + selectedText.length;
// Find which segments the selection spans
let currentIndex = 0;
let result = "";
let selectionStart = startIndex;
let selectionEnd = endIndex;
segments.forEach((segment) => {
const segmentStart = currentIndex;
const segmentEnd = segmentStart + segment.length;
// Check if this segment overlaps with the selection
if (segmentEnd > selectionStart && segmentStart < selectionEnd) {
// Calculate how much of this segment to include
const overlapStart = Math.max(0, selectionStart - segmentStart);
const overlapEnd = Math.min(segment.length, selectionEnd - segmentStart);
if (segment.type === "text") {
const textPortion = segment.text.slice(overlapStart, overlapEnd);
result += textPortion;
} else {
// For markdown elements, wrap just the selected portion with the appropriate markdown
const selectedPortion = segment.text.slice(overlapStart, overlapEnd);
switch (segment.type) {
case "bold":
result += `**${selectedPortion}**`;
break;
case "italic":
result += `*${selectedPortion}*`;
break;
case "code":
result += `\`${selectedPortion}\``;
break;
case "link":
// For links, we need to preserve the URL if it exists in the raw markdown
const urlMatch = segment.raw.match(/\]\((.*?)\)/);
const url = urlMatch ? urlMatch[1] : "";
result += `[${selectedPortion}](${url})`;
break;
case "codeblock":
result += `\`\`\`\n${selectedPortion}\n\`\`\``;
break;
default:
result += selectedPortion;
}
}
}
currentIndex += segment.length;
});
return result;
}

View File

@@ -117,8 +117,9 @@ export function ShareChatSessionModal({
{shareLink ? (
<div>
<Text>
This chat session is currently shared. Anyone in your team can
view the message history using the following link:
This chat session is currently shared. Anyone in your
organization can view the message history using the following
link:
</Text>
<div className="flex mt-2">
@@ -159,7 +160,7 @@ export function ShareChatSessionModal({
<div>
<Callout type="warning" title="Warning" className="mb-4">
Please make sure that all content in this chat is safe to
share with the whole team.
share with the whole organization.
</Callout>
<div className="flex w-full justify-between">
<Button

View File

@@ -12,7 +12,7 @@ function Main() {
<div className="mt-4">
<Callout type="danger" title="Custom Analytics is not enabled.">
To set up custom analytics scripts, please work with the team who
setup Onyx in your team to set the{" "}
setup Onyx in your organization to set the{" "}
<i>CUSTOM_ANALYTICS_SECRET_KEY</i> environment variable.
</Callout>
</div>

View File

@@ -140,7 +140,7 @@ export function WhitelabelingForm() {
<TextFormField
label="Application Name"
name="application_name"
subtext={`The custom name you are giving Onyx for your team. This will replace 'Onyx' everywhere in the UI.`}
subtext={`The custom name you are giving Onyx for your organization. This will replace 'Onyx' everywhere in the UI.`}
placeholder="Custom name which will replace 'Onyx'"
disabled={isSubmitting}
/>

View File

@@ -202,10 +202,10 @@ export function ClientLayout({
className="text-text-700"
size={18}
/>
<div className="ml-1">Actions</div>
<div className="ml-1">Tools</div>
</div>
),
link: "/admin/actions",
link: "/admin/tools",
},
]
: []),

View File

@@ -2,8 +2,19 @@
"use client";
import React, { useContext } from "react";
import Link from "next/link";
import { Logo } from "@/components/logo/Logo";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { WarningCircle, WarningDiamond } from "@phosphor-icons/react";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { CgArrowsExpandUpLeft } from "react-icons/cg";
import LogoWithText from "@/components/header/LogoWithText";
import { LogoComponent } from "@/components/logo/FixedLogo";
interface Item {

View File

@@ -33,8 +33,6 @@ import Link from "next/link";
import { CheckboxField } from "@/components/ui/checkbox";
import { CheckedState } from "@radix-ui/react-checkbox";
import { transformLinkUri } from "@/lib/utils";
export function SectionHeader({
children,
}: {
@@ -434,7 +432,6 @@ export const MarkdownFormField = ({
<ReactMarkdown
className="prose dark:prose-invert"
remarkPlugins={[remarkGfm]}
urlTransform={transformLinkUri}
>
{field.value}
</ReactMarkdown>

View File

@@ -1,154 +0,0 @@
import { useState } from "react";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import {
Table,
TableHead,
TableRow,
TableBody,
TableCell,
} from "@/components/ui/table";
import CenteredPageSelector from "./CenteredPageSelector";
import { ThreeDotsLoader } from "@/components/Loading";
import { InvitedUserSnapshot } from "@/lib/types";
import { TableHeader } from "@/components/ui/table";
import { Button } from "@/components/ui/button";
import { ErrorCallout } from "@/components/ErrorCallout";
import { FetchError } from "@/lib/fetcher";
import { CheckIcon } from "lucide-react";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
const USERS_PER_PAGE = 10;
interface Props {
users: InvitedUserSnapshot[];
setPopup: (spec: PopupSpec) => void;
mutate: () => void;
error: FetchError | null;
isLoading: boolean;
q: string;
}
const PendingUsersTable = ({
users,
setPopup,
mutate,
error,
isLoading,
q,
}: Props) => {
const [currentPageNum, setCurrentPageNum] = useState<number>(1);
const [userToApprove, setUserToApprove] = useState<string | null>(null);
if (!users.length)
return <p>Users that have requested to join will show up here</p>;
const totalPages = Math.ceil(users.length / USERS_PER_PAGE);
// Filter users based on the search query
const filteredUsers = q
? users.filter((user) => user.email.includes(q))
: users;
// Get the current page of users
const currentPageOfUsers = filteredUsers.slice(
(currentPageNum - 1) * USERS_PER_PAGE,
currentPageNum * USERS_PER_PAGE
);
if (isLoading) {
return <ThreeDotsLoader />;
}
if (error) {
return (
<ErrorCallout
errorTitle="Error loading pending users"
errorMsg={error?.info?.detail}
/>
);
}
const handleAcceptRequest = async (email: string) => {
try {
await fetch("/api/tenants/users/invite/approve", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ email }),
});
mutate();
setUserToApprove(null);
} catch (error) {
setPopup({
type: "error",
message: "Failed to approve user request",
});
}
};
return (
<>
{userToApprove && (
<ConfirmEntityModal
entityType="Join Request"
entityName={userToApprove}
onClose={() => setUserToApprove(null)}
onSubmit={() => handleAcceptRequest(userToApprove)}
actionButtonText="Approve"
actionText="approve the join request of"
additionalDetails={`${userToApprove} has requested to join the team. Approving will add them as a user in this team.`}
variant="action"
accent
removeConfirmationText
/>
)}
<Table className="overflow-visible">
<TableHeader>
<TableRow>
<TableHead>Email</TableHead>
<TableHead>
<div className="flex justify-end">Actions</div>
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{currentPageOfUsers.length ? (
currentPageOfUsers.map((user) => (
<TableRow key={user.email}>
<TableCell>{user.email}</TableCell>
<TableCell>
<div className="flex justify-end">
<Button
variant="outline"
size="sm"
onClick={() => setUserToApprove(user.email)}
>
<CheckIcon className="h-4 w-4" />
Accept Join Request
</Button>
</div>
</TableCell>
</TableRow>
))
) : (
<TableRow>
<TableCell colSpan={2} className="h-24 text-center">
{`No pending users found matching "${q}"`}
</TableCell>
</TableRow>
)}
</TableBody>
</Table>
{totalPages > 1 ? (
<CenteredPageSelector
currentPage={currentPageNum}
totalPages={totalPages}
onPageChange={setCurrentPageNum}
/>
) : null}
</>
);
};
export default PendingUsersTable;

View File

@@ -22,19 +22,19 @@ export const LeaveOrganizationButton = ({
}) => {
const router = useRouter();
const { trigger, isMutating } = useSWRMutation(
"/api/tenants/leave-team",
"/api/tenants/leave-organization",
userMutationFetcher,
{
onSuccess: () => {
mutate();
setPopup({
message: "Successfully left the team!",
message: "Successfully left the organization!",
type: "success",
});
},
onError: (errorMsg) =>
setPopup({
message: `Unable to leave team - ${errorMsg}`,
message: `Unable to leave organization - ${errorMsg}`,
type: "error",
}),
}
@@ -53,11 +53,11 @@ export const LeaveOrganizationButton = ({
<ConfirmEntityModal
variant="action"
actionButtonText="Leave"
entityType="team"
entityName="your team"
entityType="organization"
entityName="your organization"
onClose={() => setShowLeaveModal(false)}
onSubmit={handleLeaveOrganization}
additionalDetails="You will lose access to all team data and resources."
additionalDetails="You will lose access to all organization data and resources."
/>
)}

View File

@@ -4,7 +4,7 @@ import { useEffect } from "react";
import { usePopup } from "../admin/connectors/Popup";
const ERROR_MESSAGES = {
Anonymous: "Your team does not have anonymous access enabled.",
Anonymous: "Your organization does not have anonymous access enabled.",
};
export default function AuthErrorDisplay({

View File

@@ -6,7 +6,7 @@ export default function AuthFlowContainer({
authState,
}: {
children: React.ReactNode;
authState?: "signup" | "login" | "join";
authState?: "signup" | "login";
}) {
return (
<div className="p-4 flex flex-col items-center justify-center min-h-screen bg-background">

Some files were not shown because too many files have changed in this diff Show More