mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-24 11:15:47 +00:00
Compare commits
3 Commits
improed_dr
...
memory-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b169350a9 | ||
|
|
c1dbb073d0 | ||
|
|
39bfc6ae16 |
@@ -48,8 +48,6 @@ env:
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
# Notion
|
||||
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
@@ -114,4 +114,3 @@ To try the Onyx Enterprise Edition:
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
|
||||
@@ -5,10 +5,7 @@ Revises: f1ca58b2f2ec
|
||||
Create Date: 2025-01-29 07:48:46.784041
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
|
||||
@@ -18,45 +15,21 @@ down_revision = "f1ca58b2f2ec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Conflicts on lowercasing will result in the uppercased email getting a
|
||||
unique integer suffix when converted to lowercase."""
|
||||
|
||||
# Get database connection
|
||||
connection = op.get_bind()
|
||||
|
||||
# Fetch all user emails that are not already lowercase
|
||||
user_emails = connection.execute(
|
||||
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
|
||||
).fetchall()
|
||||
|
||||
for user_id, email in user_emails:
|
||||
email = cast(str, email)
|
||||
username, domain = email.rsplit("@", 1)
|
||||
new_email = f"{username.lower()}@{domain.lower()}"
|
||||
attempt = 1
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Try updating the email
|
||||
connection.execute(
|
||||
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
|
||||
{"new_email": new_email, "user_id": user_id},
|
||||
)
|
||||
break # Success, exit loop
|
||||
except IntegrityError:
|
||||
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
|
||||
# Email conflict occurred, append `_1`, `_2`, etc., to the username
|
||||
logger.warning(
|
||||
f"Conflict while lowercasing email: "
|
||||
f"old_email={email} "
|
||||
f"conflicting_email={new_email} "
|
||||
f"next_email={next_email}"
|
||||
)
|
||||
new_email = next_email
|
||||
attempt += 1
|
||||
# Update all user emails to lowercase
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
WHERE email != LOWER(email)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add new available tenant table
|
||||
|
||||
Revision ID: 3b45e0018bf1
|
||||
Revises: ac842f85f932
|
||||
Create Date: 2025-03-06 09:55:18.229910
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b45e0018bf1"
|
||||
down_revision = "ac842f85f932"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create new_available_tenant table
|
||||
op.create_table(
|
||||
"available_tenant",
|
||||
sa.Column("tenant_id", sa.String(), nullable=False),
|
||||
sa.Column("alembic_version", sa.String(), nullable=False),
|
||||
sa.Column("date_created", sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("tenant_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop new_available_tenant table
|
||||
op.drop_table("available_tenant")
|
||||
@@ -1,51 +0,0 @@
|
||||
"""new column user tenant mapping
|
||||
|
||||
Revision ID: ac842f85f932
|
||||
Revises: 34e3630c7f32
|
||||
Create Date: 2025-03-03 13:30:14.802874
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ac842f85f932"
|
||||
down_revision = "34e3630c7f32"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add active column with default value of True
|
||||
op.add_column(
|
||||
"user_tenant_mapping",
|
||||
sa.Column(
|
||||
"active",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="true",
|
||||
),
|
||||
schema="public",
|
||||
)
|
||||
|
||||
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
|
||||
|
||||
# Create a unique index for active=true records
|
||||
# This ensures a user can only be active in one tenant at a time
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the unique index for active=true records
|
||||
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
|
||||
|
||||
op.create_unique_constraint(
|
||||
"uq_email", "user_tenant_mapping", ["email"], schema="public"
|
||||
)
|
||||
|
||||
# Remove the active column
|
||||
op.drop_column("user_tenant_mapping", "active", schema="public")
|
||||
@@ -1,14 +1,10 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import AgentAnswer
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuery
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
@@ -18,19 +14,13 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
@@ -99,12 +89,6 @@ def _convert_packet_stream_to_response(
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
|
||||
# accumulate stream data with these dicts
|
||||
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
|
||||
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
|
||||
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
@@ -113,15 +97,6 @@ def _convert_packet_stream_to_response(
|
||||
|
||||
# TODO: deprecate `simple_search_docs`
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
|
||||
# This is a no-op if agent_sub_questions hasn't already been filled
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if id in agent_sub_questions:
|
||||
agent_sub_questions[id].document_ids = [
|
||||
saved_search_doc.document_id
|
||||
for saved_search_doc in packet.top_documents
|
||||
]
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
@@ -138,104 +113,11 @@ def _convert_packet_stream_to_response(
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
# agentic packets
|
||||
elif isinstance(packet, SubQuestionPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_sub_questions.get(id) is None:
|
||||
agent_sub_questions[id] = AgentSubQuestion(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_question=packet.sub_question,
|
||||
document_ids=[],
|
||||
)
|
||||
else:
|
||||
agent_sub_questions[id].sub_question += packet.sub_question
|
||||
|
||||
elif isinstance(packet, AgentAnswerPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_answers.get(id) is None:
|
||||
agent_answers[id] = AgentAnswer(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
answer=packet.answer_piece,
|
||||
answer_type=packet.answer_type,
|
||||
)
|
||||
else:
|
||||
agent_answers[id].answer += packet.answer_piece
|
||||
elif isinstance(packet, SubQueryPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
sub_query_id = (
|
||||
packet.level,
|
||||
packet.level_question_num,
|
||||
packet.query_id,
|
||||
)
|
||||
if agent_sub_queries.get(sub_query_id) is None:
|
||||
agent_sub_queries[sub_query_id] = AgentSubQuery(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_query=packet.sub_query,
|
||||
query_id=packet.query_id,
|
||||
)
|
||||
else:
|
||||
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
|
||||
elif isinstance(packet, ExtendedToolResponse):
|
||||
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
|
||||
logger.warning(
|
||||
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
|
||||
)
|
||||
elif isinstance(packet, RefinedAnswerImprovement):
|
||||
response.agent_refined_answer_improvement = (
|
||||
packet.refined_answer_improvement
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
|
||||
)
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
# organize / sort agent metadata for output
|
||||
if len(agent_sub_questions) > 0:
|
||||
response.agent_sub_questions = cast(
|
||||
dict[int, list[AgentSubQuestion]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
|
||||
)
|
||||
|
||||
if len(agent_answers) > 0:
|
||||
# return the agent_level_answer from the first level or the last one depending
|
||||
# on agent_refined_answer_improvement
|
||||
response.agent_answers = cast(
|
||||
dict[int, list[AgentAnswer]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_answers),
|
||||
)
|
||||
if response.agent_answers:
|
||||
selected_answer_level = (
|
||||
0
|
||||
if not response.agent_refined_answer_improvement
|
||||
else len(response.agent_answers) - 1
|
||||
)
|
||||
level_answers = response.agent_answers[selected_answer_level]
|
||||
for level_answer in level_answers:
|
||||
if level_answer.answer_type != "agent_level_answer":
|
||||
continue
|
||||
|
||||
answer = level_answer.answer
|
||||
break
|
||||
|
||||
if len(agent_sub_queries) > 0:
|
||||
# subqueries are often emitted with trailing whitespace ... clean it up here
|
||||
# perhaps fix at the source?
|
||||
for v in agent_sub_queries.values():
|
||||
v.sub_query = v.sub_query.strip()
|
||||
|
||||
response.agent_sub_queries = (
|
||||
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -11,7 +9,6 @@ from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
@@ -91,64 +88,6 @@ class SimpleDoc(BaseModel):
|
||||
metadata: dict | None
|
||||
|
||||
|
||||
class AgentSubQuestion(SubQuestionIdentifier):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(SubQuestionIdentifier):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(SubQuestionIdentifier):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level_and_question_index(
|
||||
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
|
||||
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
|
||||
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
|
||||
|
||||
returns a dict of level to dict[question num to list of query_id's]
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
# In this function, when we sort int | None, we deliberately push None to the end
|
||||
|
||||
# map entries to the level_question_dict
|
||||
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
|
||||
for k1, obj in original_dict.items():
|
||||
level = k1[0]
|
||||
question = k1[1]
|
||||
|
||||
if level not in level_question_dict:
|
||||
level_question_dict[level] = {}
|
||||
|
||||
if question not in level_question_dict[level]:
|
||||
level_question_dict[level][question] = []
|
||||
|
||||
level_question_dict[level][question].append(obj)
|
||||
|
||||
# sort each query_id list and question_index
|
||||
for key1, obj1 in level_question_dict.items():
|
||||
for key2, value2 in obj1.items():
|
||||
# sort the query_id list of each question_index
|
||||
level_question_dict[key1][key2] = sorted(
|
||||
value2, key=lambda o: o.query_id
|
||||
)
|
||||
# sort the question_index dict of level
|
||||
level_question_dict[key1] = OrderedDict(
|
||||
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
|
||||
# sort the top dict of levels
|
||||
sorted_dict = OrderedDict(
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
@@ -168,12 +107,6 @@ class ChatBasicResponse(BaseModel):
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
# agentic fields
|
||||
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
|
||||
agent_answers: dict[int, list[AgentAnswer]] | None = None
|
||||
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
|
||||
agent_refined_answer_improvement: bool | None = None
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simplier APIs that don't deal with chat histories or message edits
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
_: User = Depends(current_cloud_superuser),
|
||||
) -> Response:
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
@@ -1,98 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
@@ -1,24 +1,269 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.admin_api import router as admin_router
|
||||
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
|
||||
from ee.onyx.server.tenants.billing_api import router as billing_router
|
||||
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
|
||||
from ee.onyx.server.tenants.tenant_management_api import (
|
||||
router as tenant_management_router,
|
||||
)
|
||||
from ee.onyx.server.tenants.user_invitations_api import (
|
||||
router as user_invitations_router,
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Create a main router to include all sub-routers
|
||||
# Note: We don't add a prefix here as each router already has the /tenants prefix
|
||||
router = APIRouter()
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
# Include all the individual routers
|
||||
router.include_router(admin_router)
|
||||
router.include_router(anonymous_users_router)
|
||||
router.include_router(billing_router)
|
||||
router.include_router(team_membership_router)
|
||||
router.include_router(tenant_management_router)
|
||||
router.include_router(user_invitations_router)
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
return fetch_billing_information(tenant_id)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
_: User = Depends(current_cloud_superuser),
|
||||
) -> Response:
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
return fetch_billing_information(tenant_id)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -67,30 +67,3 @@ class ProductGatingResponse(BaseModel):
|
||||
|
||||
class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
creator_email: str
|
||||
|
||||
|
||||
class TenantByDomainRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class RequestInviteRequest(BaseModel):
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class RequestInviteResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class PendingUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class ApproveUserRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
@@ -4,7 +4,6 @@ import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
import httpx
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
@@ -15,7 +14,6 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
@@ -28,12 +26,11 @@ from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
@@ -63,72 +60,42 @@ async def get_or_provision_tenant(
|
||||
This function should only be called after we have verified we want this user's tenant to exist.
|
||||
It returns the tenant ID associated with the email, creating a new tenant if necessary.
|
||||
"""
|
||||
# Early return for non-multi-tenant mode
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
# First, check if the user already has a tenant
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
return tenant_id
|
||||
except exceptions.UserNotExists:
|
||||
# User doesn't exist, so we need to create a new tenant or assign an existing one
|
||||
pass
|
||||
|
||||
try:
|
||||
# Try to get a pre-provisioned tenant
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
return tenant_id
|
||||
else:
|
||||
# If no pre-provisioned tenant is available, create a new one on-demand
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
return tenant_id
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
except Exception as e:
|
||||
# If we've encountered an error, log and raise an exception
|
||||
error_msg = "Failed to provision tenant"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to provision tenant. Please try again later.",
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
"""
|
||||
Create a new tenant on-demand when no pre-provisioned tenants are available.
|
||||
This is the fallback method when we can't use a pre-provisioned tenant.
|
||||
|
||||
"""
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
logger.info(f"Creating new tenant {tenant_id} for user {email}")
|
||||
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
|
||||
# Notify control plane if not already done in provision_tenant
|
||||
if not DEV_MODE and referral_source:
|
||||
# Notify control plane
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Tenant provisioning failed: {str(e)}")
|
||||
# Attempt to rollback the tenant provisioning
|
||||
try:
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
@@ -142,25 +109,54 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
)
|
||||
|
||||
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
# Create the schema for the tenant
|
||||
if not create_schema_if_not_exists(tenant_id):
|
||||
logger.debug(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.debug(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
# Set up the tenant with all necessary configurations
|
||||
await setup_tenant(tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Assign the tenant to the user
|
||||
await assign_tenant_to_user(tenant_id, email)
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(
|
||||
@@ -191,74 +187,20 @@ async def notify_control_plane(
|
||||
|
||||
|
||||
async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
"""
|
||||
Logic to rollback tenant provisioning on data plane.
|
||||
Handles each step independently to ensure maximum cleanup even if some steps fail.
|
||||
"""
|
||||
# Logic to rollback tenant provisioning on data plane
|
||||
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
|
||||
|
||||
# Track if any part of the rollback fails
|
||||
rollback_errors = []
|
||||
|
||||
# 1. Try to drop the tenant's schema
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 2. Try to remove tenant mapping
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Successfully removed user mappings for tenant {tenant_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 3. If this tenant was in the available tenants table, remove it
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter(AvailableTenant.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Removed tenant {tenant_id} from available tenants table"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# Log summary of rollback operation
|
||||
if rollback_errors:
|
||||
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
|
||||
else:
|
||||
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
|
||||
logger.error(f"Failed to rollback tenant provisioning: {e}")
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
@@ -411,155 +353,3 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_by_domain_from_control_plane(
|
||||
domain: str,
|
||||
tenant_id: str,
|
||||
) -> TenantByDomainResponse | None:
|
||||
"""
|
||||
Fetches tenant information from the control plane based on the email domain.
|
||||
|
||||
Args:
|
||||
domain: The email domain to search for (e.g., "example.com")
|
||||
|
||||
Returns:
|
||||
A dictionary containing tenant information if found, None otherwise
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
|
||||
headers=headers,
|
||||
json={"domain": domain, "tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Control plane tenant lookup failed: {response.text}")
|
||||
return None
|
||||
|
||||
response_data = response.json()
|
||||
if not response_data:
|
||||
return None
|
||||
|
||||
return TenantByDomainResponse(
|
||||
tenant_id=response_data.get("tenant_id"),
|
||||
number_of_users=response_data.get("number_of_users"),
|
||||
creator_email=response_data.get("creator_email"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching tenant by domain: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_available_tenant() -> str | None:
|
||||
"""
|
||||
Get an available pre-provisioned tenant from the NewAvailableTenant table.
|
||||
Returns the tenant_id if one is available, None otherwise.
|
||||
Uses row-level locking to prevent race conditions when multiple processes
|
||||
try to get an available tenant simultaneously.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return None
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
db_session.begin()
|
||||
|
||||
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.order_by(AvailableTenant.date_created)
|
||||
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
tenant_id = available_tenant.tenant_id
|
||||
# Remove the tenant from the available tenants table
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(f"Using pre-provisioned tenant {tenant_id}")
|
||||
return tenant_id
|
||||
else:
|
||||
db_session.rollback()
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error getting available tenant")
|
||||
db_session.rollback()
|
||||
return None
|
||||
|
||||
|
||||
async def setup_tenant(tenant_id: str) -> None:
|
||||
"""
|
||||
Set up a tenant with all necessary configurations.
|
||||
This is a centralized function that handles all tenant setup logic.
|
||||
"""
|
||||
token = None
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Run Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
# Configure the tenant with default settings
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Configure default API keys
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
# Set up Onyx with appropriate settings
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to set up tenant {tenant_id}")
|
||||
raise e
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def assign_tenant_to_user(
|
||||
tenant_id: str, email: str, referral_source: str | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Assign a tenant to a user and perform necessary operations.
|
||||
Uses transaction handling to ensure atomicity and includes retry logic
|
||||
for control plane notifications.
|
||||
"""
|
||||
# First, add the user to the tenant in a transaction
|
||||
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
# Create milestone record in the same transaction context as the tenant assignment
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
# Notify control plane with retry logic
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
@@ -74,21 +74,3 @@ def drop_schema(tenant_id: str) -> None:
|
||||
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
|
||||
|
||||
def get_current_alembic_version(tenant_id: str) -> str:
|
||||
"""Get the current Alembic version for a tenant."""
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import text
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Set the search path to the tenant's schema
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
|
||||
|
||||
# Get the current version from the alembic_version table
|
||||
context = MigrationContext.configure(connection)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
return current_rev or "head"
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/leave-team")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
@@ -1,39 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
|
||||
"gmail",
|
||||
"outlook",
|
||||
"yahoo",
|
||||
"hotmail",
|
||||
"icloud",
|
||||
"msn",
|
||||
"hotmail",
|
||||
"hotmail.co.uk",
|
||||
]
|
||||
|
||||
|
||||
@router.get("/existing-team-by-domain")
|
||||
def get_existing_tenant_by_domain(
|
||||
user: User | None = Depends(current_user),
|
||||
) -> TenantByDomainResponse | None:
|
||||
if not user:
|
||||
return None
|
||||
domain = user.email.split("@")[1]
|
||||
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
|
||||
return None
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
return get_tenant_by_domain_from_control_plane(domain, tenant_id)
|
||||
@@ -1,90 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.models import ApproveUserRequest
|
||||
from ee.onyx.server.tenants.models import PendingUserSnapshot
|
||||
from ee.onyx.server.tenants.models import RequestInviteRequest
|
||||
from ee.onyx.server.tenants.user_mapping import accept_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import approve_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import deny_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/users/invite/request")
|
||||
async def request_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
try:
|
||||
invite_self_to_tenant(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/users/pending")
|
||||
def list_pending_users(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[PendingUserSnapshot]:
|
||||
pending_emails = get_pending_users()
|
||||
return [PendingUserSnapshot(email=email) for email in pending_emails]
|
||||
|
||||
|
||||
@router.post("/users/invite/approve")
|
||||
async def approve_user(
|
||||
approve_user_request: ApproveUserRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
approve_user_invite(approve_user_request.email, tenant_id)
|
||||
|
||||
|
||||
@router.post("/users/invite/accept")
|
||||
async def accept_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
accept_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to accept invite: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to accept invitation")
|
||||
|
||||
|
||||
@router.post("/users/invite/deny")
|
||||
async def deny_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
deny_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to deny invite: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to deny invitation")
|
||||
@@ -1,56 +1,27 @@
|
||||
import logging
|
||||
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.invited_users import write_pending_users
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.server.manage.models import TenantSnapshot
|
||||
from onyx.setup import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# First try to get an active tenant
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping).where(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
mapping = result.scalar_one_or_none()
|
||||
tenant_id = mapping.tenant_id if mapping else None
|
||||
|
||||
# If no active tenant found, try to get the first inactive one
|
||||
if tenant_id is None:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping).where(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
mapping = result.scalar_one_or_none()
|
||||
if mapping:
|
||||
# Mark this mapping as active
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
tenant_id = mapping.tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting tenant id for email {email}: {e}")
|
||||
raise exceptions.UserNotExists()
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
@@ -67,39 +38,13 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
for email in emails:
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_mapping:
|
||||
# Only add if mapping doesn't exist
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.rollback()
|
||||
raise
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
@@ -131,187 +76,3 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
pending_users = get_pending_users()
|
||||
if email in pending_users:
|
||||
return
|
||||
write_pending_users(pending_users + [email])
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def approve_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Approve a user invite to a tenant.
|
||||
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Delete all existing records for this email
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.email == email
|
||||
).delete()
|
||||
|
||||
# Create a new mapping entry for the user in this tenant
|
||||
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
|
||||
db_session.add(new_mapping)
|
||||
db_session.commit()
|
||||
|
||||
# Also remove the user from pending users list
|
||||
# Remove from pending users
|
||||
pending_users = get_pending_users()
|
||||
if email in pending_users:
|
||||
pending_users.remove(email)
|
||||
write_pending_users(pending_users)
|
||||
|
||||
# Add to invited users
|
||||
invited_users = get_invited_users()
|
||||
if email not in invited_users:
|
||||
invited_users.append(email)
|
||||
write_invited_users(invited_users)
|
||||
|
||||
|
||||
def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
This activates the user's mapping to the tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# If an active mapping exists, delete it
|
||||
if active_mapping:
|
||||
db_session.delete(active_mapping)
|
||||
logger.info(
|
||||
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# Find the inactive mapping for this user and tenant
|
||||
mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# Set all other mappings for this user to inactive
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
).update({"active": False})
|
||||
|
||||
# Activate this mapping
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.exception(
|
||||
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
This removes the user's mapping to the tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Delete the mapping for this user and tenant
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
if result:
|
||||
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
pending_users = get_invited_users()
|
||||
if email in pending_users:
|
||||
pending_users.remove(email)
|
||||
write_invited_users(pending_users)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
return user_count
|
||||
|
||||
|
||||
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
|
||||
"""
|
||||
Get the first tenant invitation for this user
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Get the first tenant invitation for this user
|
||||
invitation = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if invitation:
|
||||
# Get the user count for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == invitation.tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
return TenantSnapshot(
|
||||
tenant_id=invitation.tenant_id, number_of_users=user_count
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -62,60 +62,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
|
||||
# Cohere allows up to 96 embeddings in a single embedding calling
|
||||
_COHERE_MAX_INPUT_LEN = 96
|
||||
|
||||
# Authentication error string constants
|
||||
_AUTH_ERROR_401 = "401"
|
||||
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||
_AUTH_ERROR_PERMISSION = "permission"
|
||||
|
||||
|
||||
def is_authentication_error(error: Exception) -> bool:
|
||||
"""Check if an exception is related to authentication issues.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
bool: True if the error appears to be authentication-related
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
return (
|
||||
_AUTH_ERROR_401 in error_str
|
||||
or _AUTH_ERROR_UNAUTHORIZED in error_str
|
||||
or _AUTH_ERROR_INVALID_API_KEY in error_str
|
||||
or _AUTH_ERROR_PERMISSION in error_str
|
||||
)
|
||||
|
||||
|
||||
def format_embedding_error(
|
||||
error: Exception,
|
||||
service_name: str,
|
||||
model: str | None,
|
||||
provider: EmbeddingProvider,
|
||||
status_code: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format a standardized error string for embedding errors.
|
||||
"""
|
||||
detail = f"Status {status_code}" if status_code else f"{type(error)}"
|
||||
|
||||
return (
|
||||
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {provider} "
|
||||
f"Exception: {error}"
|
||||
)
|
||||
|
||||
|
||||
# Custom exception for authentication errors
|
||||
class AuthenticationError(Exception):
|
||||
"""Raised when authentication fails with a provider."""
|
||||
|
||||
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
|
||||
self.provider = provider
|
||||
self.message = message
|
||||
super().__init__(f"{provider} authentication failed: {message}")
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(
|
||||
@@ -146,17 +92,31 @@ class CloudEmbedding:
|
||||
)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
try:
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
except Exception as e:
|
||||
error_string = (
|
||||
f"Exception embedding text with OpenAI - {type(e)}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {self.provider} "
|
||||
f"Exception: {e}"
|
||||
)
|
||||
logger.error(error_string)
|
||||
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
# only log text when it's not an authentication error.
|
||||
if not isinstance(e, openai.AuthenticationError):
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
|
||||
async def _embed_cohere(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
@@ -195,6 +155,7 @@ class CloudEmbedding:
|
||||
input_type=embedding_type,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
return response.embeddings
|
||||
|
||||
async def _embed_azure(
|
||||
@@ -278,51 +239,22 @@ class CloudEmbedding:
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> list[Embedding]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except openai.AuthenticationError:
|
||||
raise AuthenticationError(provider="OpenAI")
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e,
|
||||
str(self.provider),
|
||||
model_name or deployment_name,
|
||||
self.provider,
|
||||
status_code=e.response.status_code,
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
except Exception as e:
|
||||
if is_authentication_error(e):
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e, str(self.provider), model_name or deployment_name, self.provider
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -637,13 +569,6 @@ async def process_embed_request(
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except AuthenticationError as e:
|
||||
# Handle authentication errors consistently
|
||||
logger.error(f"Authentication error: {e.provider}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Authentication failed: {e.message}",
|
||||
)
|
||||
except RateLimitError as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
|
||||
@@ -153,8 +153,7 @@ def send_email(
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
msg["From"] = mail_from
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import KV_PENDING_USERS_KEY
|
||||
from onyx.configs.constants import KV_USER_STORE_KEY
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -19,17 +18,3 @@ def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_kv_store()
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
|
||||
def get_pending_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
return cast(list, store.load(KV_PENDING_USERS_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_pending_users(emails: list[str]) -> int:
|
||||
store = get_kv_store()
|
||||
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
@@ -100,7 +100,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
@@ -895,7 +894,7 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accessible_user(
|
||||
async def current_chat_accesssible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -1096,12 +1095,6 @@ def get_oauth_router(
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(account_email)
|
||||
except exceptions.UserNotExists:
|
||||
tenant_id = None
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
@@ -1133,14 +1126,9 @@ def get_oauth_router(
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
redirect_url = add_url_params(next_url, {"new_team": "true"})
|
||||
redirect_response = RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers and other attributes from 'response' to 'redirect_response'
|
||||
for header_name, header_value in response.headers.items():
|
||||
@@ -1152,7 +1140,6 @@ def get_oauth_router(
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
@@ -112,6 +112,5 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -92,6 +92,5 @@ def on_setup_logging(
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -5,53 +5,40 @@ from logging.handlers import RotatingFileHandler
|
||||
|
||||
import psutil
|
||||
|
||||
from onyx.utils.logger import is_running_in_container
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Regular application logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Only set up memory monitoring in container environment
|
||||
if is_running_in_container():
|
||||
# Set up a dedicated memory monitoring logger
|
||||
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
|
||||
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
|
||||
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
|
||||
# Set up a dedicated memory monitoring logger
|
||||
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
|
||||
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
|
||||
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
|
||||
|
||||
# Ensure log directory exists
|
||||
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
|
||||
# Ensure log directory exists
|
||||
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
|
||||
|
||||
# Create a dedicated logger for memory monitoring
|
||||
memory_logger = logging.getLogger("memory_monitoring")
|
||||
memory_logger.setLevel(logging.INFO)
|
||||
# Create a dedicated logger for memory monitoring
|
||||
memory_logger = logging.getLogger("memory_monitoring")
|
||||
memory_logger.setLevel(logging.INFO)
|
||||
|
||||
# Create a rotating file handler
|
||||
memory_handler = RotatingFileHandler(
|
||||
MEMORY_LOG_FILE,
|
||||
maxBytes=MEMORY_LOG_MAX_BYTES,
|
||||
backupCount=MEMORY_LOG_BACKUP_COUNT,
|
||||
)
|
||||
# Create a rotating file handler
|
||||
memory_handler = RotatingFileHandler(
|
||||
MEMORY_LOG_FILE, maxBytes=MEMORY_LOG_MAX_BYTES, backupCount=MEMORY_LOG_BACKUP_COUNT
|
||||
)
|
||||
|
||||
# Create a formatter that includes all relevant information
|
||||
memory_formatter = logging.Formatter(
|
||||
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
memory_handler.setFormatter(memory_formatter)
|
||||
memory_logger.addHandler(memory_handler)
|
||||
else:
|
||||
# Create a null logger when not in container
|
||||
memory_logger = logging.getLogger("memory_monitoring")
|
||||
memory_logger.addHandler(logging.NullHandler())
|
||||
# Create a formatter that includes all relevant information
|
||||
memory_formatter = logging.Formatter(
|
||||
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
memory_handler.setFormatter(memory_formatter)
|
||||
memory_logger.addHandler(memory_handler)
|
||||
|
||||
|
||||
def emit_process_memory(
|
||||
pid: int, process_name: str, additional_metadata: dict[str, str | int]
|
||||
) -> None:
|
||||
# Skip memory monitoring if not in container
|
||||
if not is_running_in_container():
|
||||
return
|
||||
|
||||
try:
|
||||
process = psutil.Process(pid)
|
||||
memory_info = process.memory_info()
|
||||
|
||||
@@ -167,16 +167,6 @@ beat_cloud_tasks: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
|
||||
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# tasks that only run self hosted
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
"""
|
||||
Periodic tasks for tenant pre-provisioning.
|
||||
"""
|
||||
import asyncio
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import setup_tenant
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Default number of pre-provisioned tenants to maintain
|
||||
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
|
||||
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_available_tenants(self: Task) -> None:
|
||||
"""
|
||||
Check if we have enough pre-provisioned tenants available.
|
||||
If not, trigger the pre-provisioning of new tenants.
|
||||
"""
|
||||
task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
|
||||
if not MULTI_TENANT:
|
||||
task_logger.info(
|
||||
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
|
||||
)
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
lock_check: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# These tasks should never overlap
|
||||
if not lock_check.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
"Skipping check_available_tenants task because it is already running"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Get the current count of available tenants
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
available_tenants_count = db_session.query(AvailableTenant).count()
|
||||
|
||||
# Get the target number of available tenants
|
||||
target_available_tenants = getattr(
|
||||
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
|
||||
)
|
||||
|
||||
# Calculate how many new tenants we need to provision
|
||||
tenants_to_provision = max(
|
||||
0, target_available_tenants - available_tenants_count
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Available tenants: {available_tenants_count}, "
|
||||
f"Target: {target_available_tenants}, "
|
||||
f"To provision: {tenants_to_provision}"
|
||||
)
|
||||
|
||||
# Trigger pre-provisioning tasks for each tenant needed
|
||||
for _ in range(tenants_to_provision):
|
||||
from celery import current_app
|
||||
|
||||
current_app.send_task(
|
||||
OnyxCeleryTask.PRE_PROVISION_TENANT,
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
finally:
|
||||
lock_check.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def pre_provision_tenant(self: Task) -> None:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
This function fully sets up the tenant with all necessary configurations,
|
||||
so it's ready to be assigned to a user immediately.
|
||||
"""
|
||||
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
|
||||
# rather than inside this function
|
||||
|
||||
r = get_redis_client()
|
||||
lock_provision: RedisLock = r.lock(
|
||||
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
if not lock_provision.acquire(blocking=False):
|
||||
task_logger.debug(
|
||||
"Skipping pre_provision_tenant task because it is already running"
|
||||
)
|
||||
return
|
||||
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
# Generate a new tenant ID
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
|
||||
|
||||
# Create the schema for the new tenant
|
||||
schema_created = create_schema_if_not_exists(tenant_id)
|
||||
if schema_created:
|
||||
task_logger.debug(f"Created schema for tenant: {tenant_id}")
|
||||
else:
|
||||
task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
|
||||
|
||||
# Set up the tenant with all necessary configurations
|
||||
task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
|
||||
asyncio.run(setup_tenant(tenant_id))
|
||||
task_logger.debug(f"Tenant configuration completed: {tenant_id}")
|
||||
|
||||
# Get the current Alembic version
|
||||
alembic_version = get_current_alembic_version(tenant_id)
|
||||
task_logger.debug(
|
||||
f"Tenant {tenant_id} using Alembic version: {alembic_version}"
|
||||
)
|
||||
|
||||
# Store the pre-provisioned tenant in the database
|
||||
task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Use a transaction to ensure atomicity
|
||||
db_session.begin()
|
||||
try:
|
||||
new_tenant = AvailableTenant(
|
||||
tenant_id=tenant_id,
|
||||
alembic_version=alembic_version,
|
||||
date_created=datetime.datetime.now(),
|
||||
)
|
||||
db_session.add(new_tenant)
|
||||
db_session.commit()
|
||||
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
task_logger.error(
|
||||
f"Failed to store pre-provisioned tenant: {tenant_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
task_logger.error("Error in pre_provision_tenant task", exc_info=True)
|
||||
# If we have a tenant_id, attempt to rollback any partially completed provisioning
|
||||
if tenant_id:
|
||||
task_logger.info(
|
||||
f"Rolling back failed tenant provisioning for: {tenant_id}"
|
||||
)
|
||||
try:
|
||||
from ee.onyx.server.tenants.provisioning import (
|
||||
rollback_tenant_provisioning,
|
||||
)
|
||||
|
||||
asyncio.run(rollback_tenant_provisioning(tenant_id))
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
finally:
|
||||
lock_provision.release()
|
||||
@@ -1,13 +1,10 @@
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
@@ -47,44 +44,9 @@ class LlmDoc(BaseModel):
|
||||
|
||||
|
||||
class SubQuestionIdentifier(BaseModel):
|
||||
"""None represents references to objects in the original flow. To our understanding,
|
||||
these will not be None in the packets returned from agent search.
|
||||
"""
|
||||
|
||||
level: int | None = None
|
||||
level_question_num: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level(
|
||||
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
|
||||
) -> dict[int, list["SubQuestionIdentifier"]]:
|
||||
"""returns a dict of level to object list (sorted by level_question_num)
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
|
||||
# organize by level, then sort ascending by question_index
|
||||
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
|
||||
|
||||
# group by level
|
||||
for k, obj in original_dict.items():
|
||||
level = k[0]
|
||||
if level not in level_dict:
|
||||
level_dict[level] = []
|
||||
level_dict[level].append(obj)
|
||||
|
||||
# for each level, sort the group
|
||||
for k2, value2 in level_dict.items():
|
||||
# we need to handle the none case due to SubQuestionIdentifier typing
|
||||
# level_question_num as int | None, even though it should never be None here.
|
||||
level_dict[k2] = sorted(
|
||||
value2,
|
||||
key=lambda x: (x.level_question_num is None, x.level_question_num),
|
||||
)
|
||||
|
||||
# sort by level
|
||||
sorted_dict = OrderedDict(sorted(level_dict.items()))
|
||||
return sorted_dict
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
|
||||
@@ -374,8 +336,6 @@ class AgentAnswerPiece(SubQuestionIdentifier):
|
||||
|
||||
|
||||
class SubQuestionPiece(SubQuestionIdentifier):
|
||||
"""Refined sub questions generated from the initial user question."""
|
||||
|
||||
sub_question: str
|
||||
|
||||
|
||||
@@ -387,13 +347,13 @@ class RefinedAnswerImprovement(BaseModel):
|
||||
refined_answer_improvement: bool
|
||||
|
||||
|
||||
AgentSearchPacket = Union[
|
||||
AgentSearchPacket = (
|
||||
SubQuestionPiece
|
||||
| AgentAnswerPiece
|
||||
| SubQueryPiece
|
||||
| ExtendedToolResponse
|
||||
| RefinedAnswerImprovement
|
||||
]
|
||||
)
|
||||
|
||||
AnswerPacket = (
|
||||
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
|
||||
|
||||
@@ -643,6 +643,3 @@ MOCK_LLM_RESPONSE = (
|
||||
|
||||
|
||||
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
|
||||
|
||||
# Number of pre-provisioned tenants to maintain
|
||||
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
|
||||
|
||||
@@ -76,7 +76,6 @@ KV_REINDEX_KEY = "needs_reindexing"
|
||||
KV_SEARCH_SETTINGS = "search_settings"
|
||||
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
|
||||
KV_USER_STORE_KEY = "INVITED_USERS"
|
||||
KV_PENDING_USERS_KEY = "PENDING_USERS"
|
||||
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
|
||||
KV_CRED_KEY = "credential_id_{}"
|
||||
KV_GMAIL_CRED_KEY = "gmail_app_credential"
|
||||
@@ -322,8 +321,6 @@ class OnyxRedisLocks:
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
|
||||
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
"da_lock:connector_doc_permissions_sync"
|
||||
@@ -386,7 +383,6 @@ class OnyxCeleryTask:
|
||||
CLOUD_MONITOR_CELERY_QUEUES = (
|
||||
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
|
||||
)
|
||||
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
@@ -403,9 +399,6 @@ class OnyxCeleryTask:
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
# Tenant pre-provisioning
|
||||
PRE_PROVISION_TENANT = "pre_provision_tenant"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
|
||||
@@ -263,7 +263,6 @@ class ConfluenceConnector(
|
||||
result = process_attachment(
|
||||
self.confluence_client,
|
||||
attachment,
|
||||
page_id,
|
||||
page_title,
|
||||
self.image_analysis_llm,
|
||||
)
|
||||
@@ -367,7 +366,6 @@ class ConfluenceConnector(
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
page_context=confluence_xml,
|
||||
llm=self.image_analysis_llm,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
@@ -18,11 +19,17 @@ from requests import HTTPError
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -801,6 +808,65 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
)
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
|
||||
# why are we using session.get here? we probably won't retry these ... is that ok?
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
|
||||
@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.configs.constants import FileOrigin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -85,35 +84,25 @@ class AttachmentProcessingResult(BaseModel):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def _make_attachment_link(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
download_link = ""
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
def _download_attachment(
|
||||
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
|
||||
"""
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
resp = confluence_client._session.get(download_link)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with status code {resp.status_code}"
|
||||
)
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
return download_link
|
||||
return None
|
||||
return resp.content
|
||||
|
||||
|
||||
def process_attachment(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None,
|
||||
page_context: str,
|
||||
llm: LLM | None,
|
||||
) -> AttachmentProcessingResult:
|
||||
@@ -133,52 +122,11 @@ def process_attachment(
|
||||
error=f"Unsupported file type: {media_type}",
|
||||
)
|
||||
|
||||
attachment_link = _make_attachment_link(
|
||||
confluence_client, attachment, parent_content_id
|
||||
)
|
||||
if not attachment_link:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error="Failed to make attachment link"
|
||||
)
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
|
||||
if not media_type.startswith("image/") or not llm:
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {attachment_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Attachment text too long: {attachment_size} chars",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Downloading attachment: "
|
||||
f"title={attachment['title']} "
|
||||
f"length={attachment_size} "
|
||||
f"link={attachment_link}"
|
||||
)
|
||||
|
||||
# Download the attachment
|
||||
resp: requests.Response = confluence_client._session.get(attachment_link)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {attachment_link} with status code {resp.status_code}"
|
||||
)
|
||||
raw_bytes = _download_attachment(confluence_client, attachment)
|
||||
if raw_bytes is None:
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Attachment download status code is {resp.status_code}",
|
||||
)
|
||||
|
||||
raw_bytes = resp.content
|
||||
if not raw_bytes:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error="attachment.content is None"
|
||||
text=None, file_name=None, error="Failed to download attachment"
|
||||
)
|
||||
|
||||
# Process image attachments with LLM if available
|
||||
@@ -301,7 +249,6 @@ def _process_text_attachment(
|
||||
def convert_attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
page_id: str,
|
||||
page_context: str,
|
||||
llm: LLM | None,
|
||||
) -> tuple[str | None, str | None] | None:
|
||||
@@ -319,9 +266,7 @@ def convert_attachment_to_content(
|
||||
)
|
||||
return None
|
||||
|
||||
result = process_attachment(
|
||||
confluence_client, attachment, page_id, page_context, llm
|
||||
)
|
||||
result = process_attachment(confluence_client, attachment, page_context, llm)
|
||||
if result.error is not None:
|
||||
logger.warning(
|
||||
f"Attachment {attachment['title']} encountered error: {result.error}"
|
||||
|
||||
@@ -228,15 +228,10 @@ class GitbookConnector(LoadConnector, PollConnector):
|
||||
raise ConnectorMissingCredentialError("GitBook")
|
||||
|
||||
try:
|
||||
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
|
||||
content = self.client.get(f"/spaces/{self.space_id}/content")
|
||||
pages: list[dict[str, Any]] = content.get("pages", [])
|
||||
current_batch: list[Document] = []
|
||||
|
||||
logger.info(f"Found {len(pages)} root pages.")
|
||||
logger.info(
|
||||
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
|
||||
)
|
||||
|
||||
while pages:
|
||||
page = pages.pop(0)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
@@ -205,15 +204,6 @@ class ConnectorCheckpoint(BaseModel):
|
||||
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
|
||||
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the checkpoint, with truncation for large checkpoint content."""
|
||||
MAX_CHECKPOINT_CONTENT_CHARS = 1000
|
||||
|
||||
content_str = json.dumps(self.checkpoint_content)
|
||||
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
|
||||
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
|
||||
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
document_id: str
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import fields
|
||||
@@ -31,7 +32,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_NOTION_PAGE_SIZE = 100
|
||||
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
|
||||
|
||||
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
"""
|
||||
filtered_pages: list[NotionPage] = []
|
||||
for page in pages:
|
||||
# Parse ISO 8601 timestamp and convert to UTC epoch time
|
||||
timestamp = page[filter_field].replace(".000Z", "+00:00")
|
||||
compare_time = datetime.fromisoformat(timestamp).timestamp()
|
||||
compare_time = time.mktime(
|
||||
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
|
||||
)
|
||||
if compare_time > start and compare_time <= end:
|
||||
filtered_pages += [NotionPage(**page)]
|
||||
return filtered_pages
|
||||
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
query_dict = {
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
"page_size": _NOTION_PAGE_SIZE,
|
||||
"page_size": self.batch_size,
|
||||
}
|
||||
while True:
|
||||
db_res = self._search_notion(query_dict)
|
||||
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
return
|
||||
|
||||
query_dict = {
|
||||
"page_size": _NOTION_PAGE_SIZE,
|
||||
"page_size": self.batch_size,
|
||||
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
}
|
||||
|
||||
@@ -48,12 +48,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
domain = "test" if credentials.get("is_sandbox") else None
|
||||
self._sf_client = Salesforce(
|
||||
username=credentials["sf_username"],
|
||||
password=credentials["sf_password"],
|
||||
security_token=credentials["sf_security_token"],
|
||||
domain=domain,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -674,7 +674,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
"""
|
||||
1. Verify the bot token is valid for the workspace (via auth_test).
|
||||
2. Ensure the bot has enough scope to list channels.
|
||||
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
|
||||
3. Check that every channel specified in self.channels exists.
|
||||
"""
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
|
||||
@@ -706,8 +706,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
f"Slack API returned a failure: {error_msg}"
|
||||
)
|
||||
|
||||
# 3) If channels are specified and regex is not enabled, verify each is accessible
|
||||
if self.channels and not self.channel_regex_enabled:
|
||||
# 3) If channels are specified, verify each is accessible
|
||||
if self.channels:
|
||||
accessible_channels = get_channels(
|
||||
client=self.client,
|
||||
exclude_archived=True,
|
||||
|
||||
@@ -2295,31 +2295,21 @@ class PublicBase(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
# Strictly keeps track of the tenant that a given user will authenticate to.
|
||||
class UserTenantMapping(Base):
|
||||
__tablename__ = "user_tenant_mapping"
|
||||
__table_args__ = ({"schema": "public"},)
|
||||
__table_args__ = (
|
||||
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
|
||||
{"schema": "public"},
|
||||
)
|
||||
|
||||
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
|
||||
tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
return value.lower() if value else value
|
||||
|
||||
|
||||
class AvailableTenant(Base):
|
||||
__tablename__ = "available_tenant"
|
||||
"""
|
||||
These entries will only exist ephemerally and are meant to be picked up by new users on registration.
|
||||
"""
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False)
|
||||
alembic_version: Mapped[str] = mapped_column(String, nullable=False)
|
||||
date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
# This is a mapping from tenant IDs to anonymous user paths
|
||||
class TenantAnonymousUserPath(Base):
|
||||
__tablename__ = "tenant_anonymous_user_path"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -148,10 +149,11 @@ def delete_document_tags_for_documents__no_commit(
|
||||
stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
|
||||
db_session.execute(stmt)
|
||||
|
||||
orphan_tags_query = select(Tag.id).where(
|
||||
~db_session.query(Document__Tag.tag_id)
|
||||
.filter(Document__Tag.tag_id == Tag.id)
|
||||
.exists()
|
||||
orphan_tags_query = (
|
||||
select(Tag.id)
|
||||
.outerjoin(Document__Tag, Tag.id == Document__Tag.tag_id)
|
||||
.group_by(Tag.id)
|
||||
.having(func.count(Document__Tag.document_id) == 0)
|
||||
)
|
||||
|
||||
orphan_tags = db_session.execute(orphan_tags_query).scalars().all()
|
||||
|
||||
@@ -234,8 +234,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
yield
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await close_auth_limiter()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
@@ -112,7 +112,7 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == api_key_dep
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == current_chat_accesssible_user
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
):
|
||||
|
||||
@@ -17,7 +17,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
@@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel):
|
||||
|
||||
@router.get("/connector-status")
|
||||
def get_basic_connector_indexing_status(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
user: User = Depends(current_chat_accesssible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BasicCCPairInfo]:
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
@@ -390,7 +390,7 @@ def get_image_generation_tool(
|
||||
|
||||
@basic_router.get("")
|
||||
def list_personas(
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
include_deleted: bool = False,
|
||||
persona_ids: list[int] = Query(None),
|
||||
|
||||
@@ -7,7 +7,7 @@ from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_llm_providers_for_user
|
||||
@@ -191,7 +191,7 @@ def set_provider_as_default(
|
||||
|
||||
@basic_router.get("/provider")
|
||||
def list_llm_provider_basics(
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
return [
|
||||
|
||||
@@ -53,16 +53,6 @@ class UserPreferences(BaseModel):
|
||||
temperature_override_enabled: bool | None = None
|
||||
|
||||
|
||||
class TenantSnapshot(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
|
||||
|
||||
class TenantInfo(BaseModel):
|
||||
invitation: TenantSnapshot | None = None
|
||||
new_tenant: TenantSnapshot | None = None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
@@ -75,10 +65,9 @@ class UserInfo(BaseModel):
|
||||
current_token_created_at: datetime | None = None
|
||||
current_token_expiry_length: int | None = None
|
||||
is_cloud_superuser: bool = False
|
||||
team_name: str | None = None
|
||||
organization_name: str | None = None
|
||||
is_anonymous_user: bool | None = None
|
||||
password_configured: bool | None = None
|
||||
tenant_info: TenantInfo | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -87,9 +76,8 @@ class UserInfo(BaseModel):
|
||||
current_token_created_at: datetime | None = None,
|
||||
expiry_length: int | None = None,
|
||||
is_cloud_superuser: bool = False,
|
||||
team_name: str | None = None,
|
||||
organization_name: str | None = None,
|
||||
is_anonymous_user: bool | None = None,
|
||||
tenant_info: TenantInfo | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
@@ -111,7 +99,7 @@ class UserInfo(BaseModel):
|
||||
temperature_override_enabled=user.temperature_override_enabled,
|
||||
)
|
||||
),
|
||||
team_name=team_name,
|
||||
organization_name=organization_name,
|
||||
# set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases
|
||||
# where they previously had this set + used OIDC, and now they switched to
|
||||
# basic auth are now constantly getting redirected back to the login page
|
||||
@@ -121,7 +109,6 @@ class UserInfo(BaseModel):
|
||||
current_token_expiry_length=expiry_length,
|
||||
is_cloud_superuser=is_cloud_superuser,
|
||||
is_anonymous_user=is_anonymous_user,
|
||||
tenant_info=tenant_info,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,11 +12,13 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from psycopg2.errors import UniqueViolation
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import SUPER_USERS
|
||||
@@ -53,8 +55,6 @@ from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.manage.models import AutoScrollRequest
|
||||
from onyx.server.manage.models import TenantInfo
|
||||
from onyx.server.manage.models import TenantSnapshot
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.manage.models import UserInfo
|
||||
from onyx.server.manage.models import UserPreferences
|
||||
@@ -296,6 +296,13 @@ def bulk_invite_users(
|
||||
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
|
||||
except IntegrityError as e:
|
||||
if isinstance(e.orig, UniqueViolation):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="User has already been invited to a Onyx organization",
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
|
||||
|
||||
@@ -418,10 +425,6 @@ async def delete_user(
|
||||
db_session.expunge(user_to_delete)
|
||||
|
||||
try:
|
||||
tenant_id = get_current_tenant_id()
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)([user_email.user_email], tenant_id)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
logger.info(f"Deleted user {user_to_delete.email}")
|
||||
|
||||
@@ -550,8 +553,8 @@ def verify_user_logged_in(
|
||||
if anonymous_user_enabled(tenant_id=tenant_id):
|
||||
store = get_kv_store()
|
||||
return fetch_no_auth_user(store, anonymous_user_enabled=True)
|
||||
raise BasicAuthenticationError(detail="User Not Authenticated")
|
||||
|
||||
raise BasicAuthenticationError(detail="User Not Authenticated")
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
@@ -560,35 +563,16 @@ def verify_user_logged_in(
|
||||
token_created_at = (
|
||||
None if MULTI_TENANT else get_current_token_creation(user, db_session)
|
||||
)
|
||||
|
||||
team_name = fetch_ee_implementation_or_noop(
|
||||
organization_name = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(user.email)
|
||||
|
||||
new_tenant: TenantSnapshot | None = None
|
||||
tenant_invitation: TenantSnapshot | None = None
|
||||
|
||||
if MULTI_TENANT:
|
||||
if team_name != get_current_tenant_id():
|
||||
user_count = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_count", None
|
||||
)(team_name)
|
||||
new_tenant = TenantSnapshot(tenant_id=team_name, number_of_users=user_count)
|
||||
|
||||
tenant_invitation = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_invitation", None
|
||||
)(user.email)
|
||||
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
|
||||
is_cloud_superuser=user.email in SUPER_USERS,
|
||||
team_name=team_name,
|
||||
tenant_info=TenantInfo(
|
||||
new_tenant=new_tenant,
|
||||
invitation=tenant_invitation,
|
||||
),
|
||||
organization_name=organization_name,
|
||||
)
|
||||
|
||||
return user_info
|
||||
|
||||
@@ -49,9 +49,9 @@ class FullUserSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class DisplayPriorityRequest(BaseModel):
|
||||
display_priority_map: dict[int, int]
|
||||
|
||||
|
||||
class InvitedUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class DisplayPriorityRequest(BaseModel):
|
||||
display_priority_map: dict[int, int]
|
||||
|
||||
@@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import extract_headers
|
||||
@@ -190,7 +190,7 @@ def update_chat_session_model(
|
||||
def get_chat_session(
|
||||
session_id: UUID,
|
||||
is_shared: bool = False,
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionDetailResponse:
|
||||
user_id = user.id if user is not None else None
|
||||
@@ -246,7 +246,7 @@ def get_chat_session(
|
||||
@router.post("/create-chat-session")
|
||||
def create_new_chat_session(
|
||||
chat_session_creation_request: ChatSessionCreationRequest,
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateChatSessionID:
|
||||
user_id = user.id if user is not None else None
|
||||
@@ -381,7 +381,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
|
||||
def handle_new_chat_message(
|
||||
chat_message_req: CreateChatMessageRequest,
|
||||
request: Request,
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
_rate_limit_check: None = Depends(check_token_rate_limits),
|
||||
is_connected_func: Callable[[], bool] = Depends(is_connected),
|
||||
) -> StreamingResponse:
|
||||
@@ -473,7 +473,7 @@ def set_message_as_latest(
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
user_id = user.id if user else None
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -29,7 +29,7 @@ TOKEN_BUDGET_UNIT = 1_000
|
||||
|
||||
|
||||
def check_token_rate_limits(
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
) -> None:
|
||||
# short circuit if no rate limits are set up
|
||||
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time
|
||||
|
||||
@@ -32,15 +32,15 @@ class InCodeToolInfo(TypedDict):
|
||||
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
InCodeToolInfo(
|
||||
cls=SearchTool,
|
||||
description="The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
in_code_tool_id=SearchTool.__name__,
|
||||
display_name=SearchTool._DISPLAY_NAME,
|
||||
),
|
||||
InCodeToolInfo(
|
||||
cls=ImageGenerationTool,
|
||||
description=(
|
||||
"The Image Generation Action allows the assistant to use DALL-E 3 to generate images. "
|
||||
"The action will be used when the user asks the assistant to generate an image."
|
||||
"The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
|
||||
"The tool will be used when the user asks the assistant to generate an image."
|
||||
),
|
||||
in_code_tool_id=ImageGenerationTool.__name__,
|
||||
display_name=ImageGenerationTool._DISPLAY_NAME,
|
||||
@@ -51,7 +51,7 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
InCodeToolInfo(
|
||||
cls=InternetSearchTool,
|
||||
description=(
|
||||
"The Internet Search Action allows the assistant "
|
||||
"The Internet Search Tool allows the assistant "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
in_code_tool_id=InternetSearchTool.__name__,
|
||||
@@ -98,7 +98,7 @@ def load_builtin_tools(db_session: Session) -> None:
|
||||
for tool_id, tool in list(in_code_tool_id_to_tool.items()):
|
||||
if tool_id not in built_in_ids:
|
||||
db_session.delete(tool)
|
||||
logger.notice(f"Removed action no longer in built-in list: {tool.name}")
|
||||
logger.notice(f"Removed tool no longer in built-in list: {tool.name}")
|
||||
|
||||
db_session.commit()
|
||||
logger.notice("All built-in tools are loaded/verified.")
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlencode
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
|
||||
def add_url_params(url: str, params: dict) -> str:
|
||||
"""
|
||||
Add parameters to a URL, handling existing parameters properly.
|
||||
|
||||
Args:
|
||||
url: The original URL
|
||||
params: Dictionary of parameters to add
|
||||
|
||||
Returns:
|
||||
URL with added parameters
|
||||
"""
|
||||
# Parse the URL
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
# Get existing query parameters
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
# Update with new parameters
|
||||
for key, value in params.items():
|
||||
query_params[key] = [value]
|
||||
|
||||
# Build the new query string
|
||||
new_query = urlencode(query_params, doseq=True)
|
||||
|
||||
# Reconstruct the URL with the new query string
|
||||
new_url = urlunparse(
|
||||
(
|
||||
parsed_url.scheme,
|
||||
parsed_url.netloc,
|
||||
parsed_url.path,
|
||||
parsed_url.params,
|
||||
new_query,
|
||||
parsed_url.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
return new_url
|
||||
@@ -1,4 +1,4 @@
|
||||
black==23.7.0
|
||||
black==23.3.0
|
||||
boto3-stubs[s3]==1.34.133
|
||||
celery-types==0.19.0
|
||||
cohere==5.6.1
|
||||
|
||||
@@ -54,7 +54,6 @@ class OnyxRedisCommand(Enum):
|
||||
purge_vespa_syncing = "purge_vespa_syncing"
|
||||
get_user_token = "get_user_token"
|
||||
delete_user_token = "delete_user_token"
|
||||
add_invited_user = "add_invited_user"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
@@ -164,21 +163,6 @@ def onyx_redis(
|
||||
return 0
|
||||
else:
|
||||
return 2
|
||||
elif command == OnyxRedisCommand.add_invited_user:
|
||||
if not user_email:
|
||||
logger.error("You must specify --user-email with add_invited_user")
|
||||
return 1
|
||||
current_invited_users = get_invited_users()
|
||||
if user_email not in current_invited_users:
|
||||
current_invited_users.append(user_email)
|
||||
if dry_run:
|
||||
logger.info(f"(DRY-RUN) Would add {user_email} to invited users")
|
||||
else:
|
||||
write_invited_users(current_invited_users)
|
||||
logger.info(f"Added {user_email} to invited users")
|
||||
else:
|
||||
logger.info(f"{user_email} is already in the invited users list")
|
||||
return 0
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -457,6 +441,23 @@ if __name__ == "__main__":
|
||||
if args.tenant_id:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(args.tenant_id)
|
||||
|
||||
if args.command == "add_invited_user":
|
||||
if not args.user_email:
|
||||
print("Error: --user-email is required for add_invited_user command")
|
||||
sys.exit(1)
|
||||
|
||||
current_invited_users = get_invited_users()
|
||||
if args.user_email not in current_invited_users:
|
||||
current_invited_users.append(args.user_email)
|
||||
if args.dry_run:
|
||||
print(f"(DRY-RUN) Would add {args.user_email} to invited users")
|
||||
else:
|
||||
write_invited_users(current_invited_users)
|
||||
print(f"Added {args.user_email} to invited users")
|
||||
else:
|
||||
print(f"{args.user_email} is already in the invited users list")
|
||||
sys.exit(0)
|
||||
|
||||
exitcode = onyx_redis(
|
||||
command=args.command,
|
||||
batch=args.batch,
|
||||
|
||||
@@ -36,7 +36,6 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
@pytest.mark.skip(reason="Skipping this test")
|
||||
def test_confluence_connector_basic(
|
||||
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
|
||||
) -> None:
|
||||
|
||||
@@ -28,7 +28,6 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
|
||||
# This should never fail because even if the docs in the cloud change,
|
||||
# the full doc ids retrieved should always be a subset of the slim doc ids
|
||||
@pytest.mark.skip(reason="Skipping this test")
|
||||
def test_confluence_connector_permissions(
|
||||
confluence_connector: ConfluenceConnector,
|
||||
) -> None:
|
||||
|
||||
@@ -20,32 +20,29 @@ def gitbook_connector() -> GitbookConnector:
|
||||
return connector
|
||||
|
||||
|
||||
NUM_PAGES = 3
|
||||
|
||||
|
||||
def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
|
||||
doc_batch_generator = gitbook_connector.load_from_state()
|
||||
|
||||
# Get first batch of documents
|
||||
doc_batch = next(doc_batch_generator)
|
||||
assert len(doc_batch) == NUM_PAGES
|
||||
assert len(doc_batch) > 0
|
||||
|
||||
# Verify first document structure
|
||||
main_doc = doc_batch[0]
|
||||
doc = doc_batch[0]
|
||||
|
||||
# Basic document properties
|
||||
assert main_doc.id.startswith("gitbook-")
|
||||
assert main_doc.semantic_identifier == "Acme Corp Internal Handbook"
|
||||
assert main_doc.source == DocumentSource.GITBOOK
|
||||
assert doc.id.startswith("gitbook-")
|
||||
assert doc.semantic_identifier == "Acme Corp Internal Handbook"
|
||||
assert doc.source == DocumentSource.GITBOOK
|
||||
|
||||
# Metadata checks
|
||||
assert "path" in main_doc.metadata
|
||||
assert "type" in main_doc.metadata
|
||||
assert "kind" in main_doc.metadata
|
||||
assert "path" in doc.metadata
|
||||
assert "type" in doc.metadata
|
||||
assert "kind" in doc.metadata
|
||||
|
||||
# Section checks
|
||||
assert len(main_doc.sections) == 1
|
||||
section = main_doc.sections[0]
|
||||
assert len(doc.sections) == 1
|
||||
section = doc.sections[0]
|
||||
|
||||
# Content specific checks
|
||||
content = section.text
|
||||
@@ -77,23 +74,8 @@ def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
|
||||
|
||||
assert section.link # Should have a URL
|
||||
|
||||
nested1 = doc_batch[1]
|
||||
assert nested1.id.startswith("gitbook-")
|
||||
assert nested1.semantic_identifier == "Nested1"
|
||||
assert len(nested1.sections) == 1
|
||||
# extra newlines at the end, remove them to make test easier
|
||||
assert nested1.sections[0].text.strip() == "nested1"
|
||||
assert nested1.source == DocumentSource.GITBOOK
|
||||
|
||||
nested2 = doc_batch[2]
|
||||
assert nested2.id.startswith("gitbook-")
|
||||
assert nested2.semantic_identifier == "Nested2"
|
||||
assert len(nested2.sections) == 1
|
||||
assert nested2.sections[0].text.strip() == "nested2"
|
||||
assert nested2.source == DocumentSource.GITBOOK
|
||||
|
||||
# Time-based polling test
|
||||
current_time = time.time()
|
||||
poll_docs = gitbook_connector.poll_source(0, current_time)
|
||||
poll_batch = next(poll_docs)
|
||||
assert len(poll_batch) == NUM_PAGES
|
||||
assert len(poll_batch) > 0
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.notion.connector import NotionConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notion_connector() -> NotionConnector:
|
||||
"""Create a NotionConnector with credentials from environment variables"""
|
||||
connector = NotionConnector()
|
||||
connector.load_credentials(
|
||||
{
|
||||
"notion_integration_token": os.environ["NOTION_INTEGRATION_TOKEN"],
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
def test_notion_connector_basic(notion_connector: NotionConnector) -> None:
|
||||
"""Test the NotionConnector with a real Notion page.
|
||||
|
||||
Uses a Notion workspace under the onyx-test.com domain.
|
||||
"""
|
||||
doc_batch_generator = notion_connector.poll_source(0, time.time())
|
||||
|
||||
# Get first batch of documents
|
||||
doc_batch = next(doc_batch_generator)
|
||||
assert (
|
||||
len(doc_batch) == 5
|
||||
), "Expected exactly 5 documents (root, two children, table entry, and table entry child)"
|
||||
|
||||
# Find root and child documents by semantic identifier
|
||||
root_doc = None
|
||||
child1_doc = None
|
||||
child2_doc = None
|
||||
table_entry_doc = None
|
||||
table_entry_child_doc = None
|
||||
for doc in doc_batch:
|
||||
if doc.semantic_identifier == "Root":
|
||||
root_doc = doc
|
||||
elif doc.semantic_identifier == "Child1":
|
||||
child1_doc = doc
|
||||
elif doc.semantic_identifier == "Child2":
|
||||
child2_doc = doc
|
||||
elif doc.semantic_identifier == "table-entry01":
|
||||
table_entry_doc = doc
|
||||
elif doc.semantic_identifier == "Child-table-entry01":
|
||||
table_entry_child_doc = doc
|
||||
|
||||
assert root_doc is not None, "Root document not found"
|
||||
assert child1_doc is not None, "Child1 document not found"
|
||||
assert child2_doc is not None, "Child2 document not found"
|
||||
assert table_entry_doc is not None, "Table entry document not found"
|
||||
assert table_entry_child_doc is not None, "Table entry child document not found"
|
||||
|
||||
# Verify root document structure
|
||||
assert root_doc.id is not None
|
||||
assert root_doc.source == DocumentSource.NOTION
|
||||
|
||||
# Section checks for root
|
||||
assert len(root_doc.sections) == 1
|
||||
root_section = root_doc.sections[0]
|
||||
|
||||
# Content specific checks for root
|
||||
assert root_section.text == "\nroot"
|
||||
assert root_section.link is not None
|
||||
assert root_section.link.startswith("https://www.notion.so/")
|
||||
|
||||
# Verify child1 document structure
|
||||
assert child1_doc.id is not None
|
||||
assert child1_doc.source == DocumentSource.NOTION
|
||||
|
||||
# Section checks for child1
|
||||
assert len(child1_doc.sections) == 1
|
||||
child1_section = child1_doc.sections[0]
|
||||
|
||||
# Content specific checks for child1
|
||||
assert child1_section.text == "\nchild1"
|
||||
assert child1_section.link is not None
|
||||
assert child1_section.link.startswith("https://www.notion.so/")
|
||||
|
||||
# Verify child2 document structure (includes database)
|
||||
assert child2_doc.id is not None
|
||||
assert child2_doc.source == DocumentSource.NOTION
|
||||
|
||||
# Section checks for child2
|
||||
assert len(child2_doc.sections) == 2 # One for content, one for database
|
||||
child2_section = child2_doc.sections[0]
|
||||
child2_db_section = child2_doc.sections[1]
|
||||
|
||||
# Content specific checks for child2
|
||||
assert child2_section.text == "\nchild2"
|
||||
assert child2_section.link is not None
|
||||
assert child2_section.link.startswith("https://www.notion.so/")
|
||||
|
||||
# Database section checks for child2
|
||||
assert child2_db_section.text.strip() != "" # Should contain some database content
|
||||
assert child2_db_section.link is not None
|
||||
assert child2_db_section.link.startswith("https://www.notion.so/")
|
||||
|
||||
# Verify table entry document structure
|
||||
assert table_entry_doc.id is not None
|
||||
assert table_entry_doc.source == DocumentSource.NOTION
|
||||
|
||||
# Section checks for table entry
|
||||
assert len(table_entry_doc.sections) == 1
|
||||
table_entry_section = table_entry_doc.sections[0]
|
||||
|
||||
# Content specific checks for table entry
|
||||
assert table_entry_section.text == "\ntable-entry01"
|
||||
assert table_entry_section.link is not None
|
||||
assert table_entry_section.link.startswith("https://www.notion.so/")
|
||||
|
||||
# Verify table entry child document structure
|
||||
assert table_entry_child_doc.id is not None
|
||||
assert table_entry_child_doc.source == DocumentSource.NOTION
|
||||
|
||||
# Section checks for table entry child
|
||||
assert len(table_entry_child_doc.sections) == 1
|
||||
table_entry_child_section = table_entry_child_doc.sections[0]
|
||||
|
||||
# Content specific checks for table entry child
|
||||
assert table_entry_child_section.text == "\nchild-table-entry01"
|
||||
assert table_entry_child_section.link is not None
|
||||
assert table_entry_child_section.link.startswith("https://www.notion.so/")
|
||||
@@ -2,7 +2,7 @@ FROM python:3.11.7-slim-bookworm
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN pip install "pydantic-core>=2.28.0" fastapi uvicorn
|
||||
RUN pip install fastapi uvicorn
|
||||
|
||||
COPY ./main.py /app/main.py
|
||||
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
@@ -22,58 +17,3 @@ def test_send_message_simple_with_history(reset: None) -> None:
|
||||
)
|
||||
|
||||
assert len(response.full_message) > 0
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="enable for autorun when we have a testing environment with semantically useful data"
|
||||
)
|
||||
def test_send_message_simple_with_history_buffered() -> None:
|
||||
import requests
|
||||
|
||||
API_KEY = "" # fill in for this to work
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {API_KEY}"
|
||||
|
||||
req: dict[str, Any] = {}
|
||||
|
||||
req["persona_id"] = 0
|
||||
req["description"] = "test_send_message_simple_with_history_buffered"
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session", headers=headers, json=req
|
||||
)
|
||||
chat_session_id = response.json()["chat_session_id"]
|
||||
|
||||
req = {}
|
||||
req["chat_session_id"] = chat_session_id
|
||||
req["message"] = "What does onyx do?"
|
||||
req["use_agentic_search"] = True
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-api", headers=headers, json=req
|
||||
)
|
||||
|
||||
r_json = response.json()
|
||||
|
||||
# all of these should exist and be greater than length 1
|
||||
assert len(r_json.get("answer", "")) > 0
|
||||
assert len(r_json.get("agent_sub_questions", "")) > 0
|
||||
assert len(r_json.get("agent_answers")) > 0
|
||||
assert len(r_json.get("agent_sub_queries")) > 0
|
||||
assert "agent_refined_answer_improvement" in r_json
|
||||
|
||||
# top level answer should match the one we select out of agent_answers
|
||||
answer_level = 0
|
||||
agent_level_answer = ""
|
||||
|
||||
agent_refined_answer_improvement = r_json.get("agent_refined_answer_improvement")
|
||||
if agent_refined_answer_improvement:
|
||||
answer_level = len(r_json["agent_answers"]) - 1
|
||||
|
||||
answers = r_json["agent_answers"][str(answer_level)]
|
||||
for answer in answers:
|
||||
if answer["answer_type"] == "agent_level_answer":
|
||||
agent_level_answer = answer["answer"]
|
||||
break
|
||||
|
||||
assert r_json["answer"] == agent_level_answer
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -1095,7 +1095,8 @@ export function AssistantEditor({
|
||||
|
||||
{values.is_public ? (
|
||||
<p className="text-sm text-text-dark">
|
||||
Anyone from your team can view and use this assistant
|
||||
Anyone from your organization can view and use this
|
||||
assistant
|
||||
</p>
|
||||
) : (
|
||||
<>
|
||||
|
||||
@@ -177,11 +177,6 @@ export function PersonasTable() {
|
||||
entityName={personaToToggleDefault.name}
|
||||
onClose={closeDefaultModal}
|
||||
onSubmit={handleToggleDefault}
|
||||
actionText={
|
||||
personaToToggleDefault.is_default_persona
|
||||
? "remove the featured status of"
|
||||
: "set as featured"
|
||||
}
|
||||
actionButtonText={
|
||||
personaToToggleDefault.is_default_persona
|
||||
? "Remove Featured"
|
||||
|
||||
@@ -121,7 +121,7 @@ function Main() {
|
||||
);
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
function Page() {
|
||||
return (
|
||||
<div className="mx-auto container">
|
||||
<AdminPageTitle
|
||||
@@ -132,3 +132,5 @@ export default function Page() {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default Page;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { Button } from "@/components/Button";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
@@ -7,14 +8,10 @@ import { adminDeleteCredential } from "@/lib/credential";
|
||||
import { setupGoogleDriveOAuth } from "@/lib/googleDrive";
|
||||
import { GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
|
||||
import Cookies from "js-cookie";
|
||||
import {
|
||||
TextFormField,
|
||||
SectionHeader,
|
||||
SubLabel,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { Form, Formik } from "formik";
|
||||
import { User } from "@/lib/types";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Button as TremorButton } from "@/components/ui/button";
|
||||
import {
|
||||
Credential,
|
||||
GoogleDriveCredentialJson,
|
||||
@@ -23,15 +20,6 @@ import {
|
||||
import { refreshAllGoogleData } from "@/lib/googleConnector";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
import {
|
||||
FiFile,
|
||||
FiUpload,
|
||||
FiTrash2,
|
||||
FiCheck,
|
||||
FiLink,
|
||||
FiAlertTriangle,
|
||||
} from "react-icons/fi";
|
||||
import { cn, truncateString } from "@/lib/utils";
|
||||
|
||||
type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account";
|
||||
|
||||
@@ -43,202 +31,126 @@ export const DriveJsonUpload = ({
|
||||
onSuccess?: () => void;
|
||||
}) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [fileName, setFileName] = useState<string | undefined>();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
|
||||
const handleFileUpload = async (file: File) => {
|
||||
setIsUploading(true);
|
||||
setFileName(file.name);
|
||||
|
||||
const reader = new FileReader();
|
||||
reader.onload = async (loadEvent) => {
|
||||
if (!loadEvent?.target?.result) {
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const credentialJsonStr = loadEvent.target.result as string;
|
||||
|
||||
// Check credential type
|
||||
let credentialFileType: GoogleDriveCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
setIsUploading(false);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
};
|
||||
|
||||
const handleDragEnter = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (!isUploading) {
|
||||
setIsDragging(true);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDragLeave = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
};
|
||||
|
||||
const handleDragOver = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
const handleDrop = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
|
||||
if (isUploading) return;
|
||||
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
const file = files[0];
|
||||
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||
handleFileUpload(file);
|
||||
} else {
|
||||
setPopup({
|
||||
message: "Please upload a JSON file",
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
const [credentialJsonStr, setCredentialJsonStr] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col mt-4">
|
||||
<div className="flex items-center">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
isUploading
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: isDragging
|
||||
? "bg-background-50/50 border-primary dark:border-primary"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
onDragEnter={handleDragEnter}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDragOver={handleDragOver}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{isUploading ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{isUploading
|
||||
? `Uploading ${truncateString(fileName || "file", 50)}...`
|
||||
: isDragging
|
||||
? "Drop JSON file here"
|
||||
: truncateString(
|
||||
fileName || "Select or drag JSON credentials file...",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
className="sr-only"
|
||||
type="file"
|
||||
accept=".json"
|
||||
disabled={isUploading}
|
||||
onChange={(event) => {
|
||||
if (!event.target.files?.length) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
handleFileUpload(file);
|
||||
}}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<>
|
||||
<input
|
||||
className={
|
||||
"mr-3 text-sm text-text-900 border border-background-300 " +
|
||||
"cursor-pointer bg-backgrournd dark:text-text-400 focus:outline-none " +
|
||||
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
|
||||
}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={(event) => {
|
||||
if (!event.target.files) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
const reader = new FileReader();
|
||||
|
||||
reader.onload = function (loadEvent) {
|
||||
if (!loadEvent?.target?.result) {
|
||||
return;
|
||||
}
|
||||
const fileContents = loadEvent.target.result;
|
||||
setCredentialJsonStr(fileContents as string);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
}}
|
||||
/>
|
||||
|
||||
<Button
|
||||
disabled={!credentialJsonStr}
|
||||
onClick={async () => {
|
||||
let credentialFileType: GoogleDriveCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr!);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
Upload
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -248,7 +160,6 @@ interface DriveJsonUploadSectionProps {
|
||||
serviceAccountCredentialData?: { service_account_email: string };
|
||||
isAdmin: boolean;
|
||||
onSuccess?: () => void;
|
||||
existingAuthCredential?: boolean;
|
||||
}
|
||||
|
||||
export const DriveJsonUploadSection = ({
|
||||
@@ -257,7 +168,6 @@ export const DriveJsonUploadSection = ({
|
||||
serviceAccountCredentialData,
|
||||
isAdmin,
|
||||
onSuccess,
|
||||
existingAuthCredential,
|
||||
}: DriveJsonUploadSectionProps) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const router = useRouter();
|
||||
@@ -267,7 +177,6 @@ export const DriveJsonUploadSection = ({
|
||||
const [localAppCredentialData, setLocalAppCredentialData] =
|
||||
useState(appCredentialData);
|
||||
|
||||
// Update local state when props change
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountCredentialData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
@@ -281,135 +190,153 @@ export const DriveJsonUploadSection = ({
|
||||
}
|
||||
};
|
||||
|
||||
if (!isAdmin) {
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
return (
|
||||
<div>
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Curators are unable to set up the Google Drive credentials. To add a
|
||||
Google Drive connector, please contact an administrator.
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing service account key with the following <b>Email:</b>
|
||||
<p className="italic mt-1">
|
||||
{localServiceAccountData.service_account_email}
|
||||
</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
setLocalServiceAccountData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing app credentials with the following <b>Client ID:</b>
|
||||
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/app-credential"
|
||||
);
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted app credentials",
|
||||
type: "success",
|
||||
});
|
||||
setLocalAppCredentialData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isAdmin) {
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Curators are unable to set up the google drive credentials. To add a
|
||||
Google Drive connector, please contact an administrator.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<p className="text-sm mb-3">
|
||||
To connect your Google Drive, create credentials (either OAuth App or
|
||||
Service Account), download the JSON file, and upload it below.
|
||||
</p>
|
||||
<div className="mb-4">
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Follow the guide{" "}
|
||||
<a
|
||||
className="text-primary hover:text-primary/80 flex items-center gap-1 text-sm"
|
||||
className="text-link"
|
||||
target="_blank"
|
||||
href="https://docs.onyx.app/connectors/google_drive#authorization"
|
||||
rel="noreferrer"
|
||||
>
|
||||
<FiLink className="h-3 w-3" />
|
||||
View detailed setup instructions
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{(localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id) && (
|
||||
<div className="mb-4">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
false
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{false ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{truncateString(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id ||
|
||||
"",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
{isAdmin && !existingAuthCredential && (
|
||||
<div className="mt-2">
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
const endpoint =
|
||||
localServiceAccountData?.service_account_email
|
||||
? "/api/manage/admin/connector/google-drive/service-account-key"
|
||||
: "/api/manage/admin/connector/google-drive/app-credential";
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: "DELETE",
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
mutate(endpoint);
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
|
||||
// Add additional mutations to refresh all credential-related endpoints
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/credentials"
|
||||
);
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/public-credential"
|
||||
);
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-credential"
|
||||
);
|
||||
|
||||
setPopup({
|
||||
message: `Successfully deleted ${
|
||||
localServiceAccountData
|
||||
? "service account key"
|
||||
: "app credentials"
|
||||
}`,
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
if (localServiceAccountData) {
|
||||
setLocalServiceAccountData(undefined);
|
||||
} else {
|
||||
setLocalAppCredentialData(undefined);
|
||||
}
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete Credentials
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id
|
||||
) && <DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />}
|
||||
here
|
||||
</a>{" "}
|
||||
to either (1) setup a google OAuth App in your company workspace or (2)
|
||||
create a Service Account.
|
||||
<br />
|
||||
<br />
|
||||
Download the credentials JSON if choosing option (1) or the Service
|
||||
Account key JSON if chooosing option (2), and upload it here.
|
||||
</p>
|
||||
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -464,7 +391,6 @@ export const DriveAuthSection = ({
|
||||
user,
|
||||
}: DriveCredentialSectionProps) => {
|
||||
const router = useRouter();
|
||||
const [isAuthenticating, setIsAuthenticating] = useState(false);
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
serviceAccountKeyData
|
||||
);
|
||||
@@ -479,7 +405,6 @@ export const DriveAuthSection = ({
|
||||
setLocalGoogleDriveServiceAccountCredential,
|
||||
] = useState(googleDriveServiceAccountCredential);
|
||||
|
||||
// Update local state when props change
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountKeyData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
@@ -499,181 +424,126 @@ export const DriveAuthSection = ({
|
||||
localGoogleDriveServiceAccountCredential;
|
||||
if (existingCredential) {
|
||||
return (
|
||||
<div>
|
||||
<div className="mt-4">
|
||||
<div className="py-3 px-4 bg-blue-50/30 dark:bg-blue-900/5 rounded mb-4 flex items-start">
|
||||
<FiCheck className="text-blue-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<div className="flex-1">
|
||||
<span className="font-medium block">Authentication Complete</span>
|
||||
<p className="text-sm mt-1 text-text-500 dark:text-text-400 break-words">
|
||||
Your Google Drive credentials have been successfully uploaded
|
||||
and authenticated.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorAssociated,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// If no credentials are uploaded, show message to complete step 1 first
|
||||
if (
|
||||
!localServiceAccountData?.service_account_email &&
|
||||
!localAppCredentialData?.client_id
|
||||
) {
|
||||
return (
|
||||
<div>
|
||||
<SectionHeader>Google Drive Authentication</SectionHeader>
|
||||
<div className="mt-4">
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Please complete Step 1 by uploading either OAuth credentials or a
|
||||
Service Account key before proceeding with authentication.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<>
|
||||
<p className="mb-2 text-sm">
|
||||
<i>Uploaded and authenticated credential already exists!</i>
|
||||
</p>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorAssociated,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
return (
|
||||
<div>
|
||||
<div className="mt-4">
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_primary_admin: user?.email || "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_primary_admin: Yup.string()
|
||||
.email("Must be a valid email")
|
||||
.required("Required"),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
google_primary_admin: values.google_primary_admin,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
refreshCredentials();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${error}`,
|
||||
type: "error",
|
||||
});
|
||||
} finally {
|
||||
formikHelpers.setSubmitting(false);
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_primary_admin: user?.email || "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_primary_admin: Yup.string().required(
|
||||
"User email is required"
|
||||
),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/google-drive/service-account-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
google_primary_admin: values.google_primary_admin,
|
||||
}),
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_primary_admin"
|
||||
label="Primary Admin Email:"
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
{isSubmitting ? "Creating..." : "Create Credential"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
refreshCredentials();
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_primary_admin"
|
||||
label="Primary Admin Email:"
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<TremorButton type="submit" disabled={isSubmitting}>
|
||||
Create Credential
|
||||
</TremorButton>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
return (
|
||||
<div>
|
||||
<div className="bg-background-50/30 dark:bg-background-900/20 rounded mb-4">
|
||||
<p className="text-sm">
|
||||
Next, you need to authenticate with Google Drive via OAuth. This
|
||||
gives us read access to the documents you have access to in your
|
||||
Google Drive account.
|
||||
</p>
|
||||
</div>
|
||||
<div className="text-sm mb-4">
|
||||
<p className="mb-2">
|
||||
Next, you must provide credentials via OAuth. This gives us read
|
||||
access to the docs you have access to in your google drive account.
|
||||
</p>
|
||||
<Button
|
||||
disabled={isAuthenticating}
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
|
||||
isAdmin: true,
|
||||
name: "OAuth (uploaded)",
|
||||
});
|
||||
if (authUrl) {
|
||||
// cookie used by callback to determine where to finally redirect to
|
||||
Cookies.set(GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
});
|
||||
|
||||
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
|
||||
isAdmin: true,
|
||||
name: "OAuth (uploaded)",
|
||||
});
|
||||
|
||||
if (authUrl) {
|
||||
router.push(authUrl);
|
||||
} else {
|
||||
setPopup({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
});
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: `Failed to authenticate with Google Drive - ${error}`,
|
||||
type: "error",
|
||||
});
|
||||
setIsAuthenticating(false);
|
||||
router.push(authUrl);
|
||||
return;
|
||||
}
|
||||
|
||||
setPopup({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
});
|
||||
}}
|
||||
>
|
||||
{isAuthenticating
|
||||
? "Authenticating..."
|
||||
: "Authenticate with Google Drive"}
|
||||
Authenticate with Google Drive
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// This code path should not be reached with the new conditions above
|
||||
return null;
|
||||
// case where no keys have been uploaded in step 1
|
||||
return (
|
||||
<p className="text-sm">
|
||||
Please upload either a OAuth Client Credential JSON or a Google Drive
|
||||
Service Account Key JSON in Step 1 before moving onto Step 2.
|
||||
</p>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -165,10 +165,6 @@ const GDriveMain = ({
|
||||
serviceAccountCredentialData={serviceAccountKeyData}
|
||||
isAdmin={isAdmin}
|
||||
onSuccess={handleRefresh}
|
||||
existingAuthCredential={Boolean(
|
||||
googleDrivePublicUploadedCredential ||
|
||||
googleDriveServiceAccountCredential
|
||||
)}
|
||||
/>
|
||||
|
||||
{isAdmin &&
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Button } from "@/components/Button";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
@@ -8,11 +8,7 @@ import { adminDeleteCredential } from "@/lib/credential";
|
||||
import { setupGmailOAuth } from "@/lib/gmail";
|
||||
import { GMAIL_AUTH_IS_ADMIN_COOKIE_NAME } from "@/lib/constants";
|
||||
import Cookies from "js-cookie";
|
||||
import {
|
||||
TextFormField,
|
||||
SectionHeader,
|
||||
SubLabel,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { Form, Formik } from "formik";
|
||||
import { User } from "@/lib/types";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
@@ -24,19 +20,10 @@ import {
|
||||
import { refreshAllGoogleData } from "@/lib/googleConnector";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
import {
|
||||
FiFile,
|
||||
FiUpload,
|
||||
FiTrash2,
|
||||
FiCheck,
|
||||
FiLink,
|
||||
FiAlertTriangle,
|
||||
} from "react-icons/fi";
|
||||
import { cn, truncateString } from "@/lib/utils";
|
||||
|
||||
type GmailCredentialJsonTypes = "authorized_user" | "service_account";
|
||||
|
||||
const GmailCredentialUpload = ({
|
||||
const DriveJsonUpload = ({
|
||||
setPopup,
|
||||
onSuccess,
|
||||
}: {
|
||||
@@ -44,210 +31,134 @@ const GmailCredentialUpload = ({
|
||||
onSuccess?: () => void;
|
||||
}) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [fileName, setFileName] = useState<string | undefined>();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
|
||||
const handleFileUpload = async (file: File) => {
|
||||
setIsUploading(true);
|
||||
setFileName(file.name);
|
||||
|
||||
const reader = new FileReader();
|
||||
reader.onload = async (loadEvent) => {
|
||||
if (!loadEvent?.target?.result) {
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const credentialJsonStr = loadEvent.target.result as string;
|
||||
|
||||
// Check credential type
|
||||
let credentialFileType: GmailCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
setIsUploading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/service-account-key");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
setIsUploading(false);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
};
|
||||
|
||||
const handleDragEnter = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
if (!isUploading) {
|
||||
setIsDragging(true);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDragLeave = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
};
|
||||
|
||||
const handleDragOver = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
const handleDrop = (e: React.DragEvent<HTMLLabelElement>) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
setIsDragging(false);
|
||||
|
||||
if (isUploading) return;
|
||||
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
const file = files[0];
|
||||
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||
handleFileUpload(file);
|
||||
} else {
|
||||
setPopup({
|
||||
message: "Please upload a JSON file",
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
const [credentialJsonStr, setCredentialJsonStr] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
|
||||
return (
|
||||
<div className="flex flex-col mt-4">
|
||||
<div className="flex items-center">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
isUploading
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: isDragging
|
||||
? "bg-background-50/50 border-primary dark:border-primary"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
onDragEnter={handleDragEnter}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDragOver={handleDragOver}
|
||||
onDrop={handleDrop}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{isUploading ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{isUploading
|
||||
? `Uploading ${truncateString(fileName || "file", 50)}...`
|
||||
: isDragging
|
||||
? "Drop JSON file here"
|
||||
: truncateString(
|
||||
fileName || "Select or drag JSON credentials file...",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
className="sr-only"
|
||||
type="file"
|
||||
accept=".json"
|
||||
disabled={isUploading}
|
||||
onChange={(event) => {
|
||||
if (!event.target.files?.length) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
handleFileUpload(file);
|
||||
}}
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<>
|
||||
<input
|
||||
className={
|
||||
"mr-3 text-sm text-text-900 border border-background-300 overflow-visible " +
|
||||
"cursor-pointer bg-background dark:text-text-400 focus:outline-none " +
|
||||
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
|
||||
}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={(event) => {
|
||||
if (!event.target.files) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
const reader = new FileReader();
|
||||
|
||||
reader.onload = function (loadEvent) {
|
||||
if (!loadEvent?.target?.result) {
|
||||
return;
|
||||
}
|
||||
const fileContents = loadEvent.target.result;
|
||||
setCredentialJsonStr(fileContents as string);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
}}
|
||||
/>
|
||||
|
||||
<Button
|
||||
disabled={!credentialJsonStr}
|
||||
onClick={async () => {
|
||||
// check if the JSON is a app credential or a service account credential
|
||||
let credentialFileType: GmailCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr!);
|
||||
if (appCredentialJson.web) {
|
||||
credentialFileType = "authorized_user";
|
||||
} else if (appCredentialJson.type === "service_account") {
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
setPopup({
|
||||
message: `Invalid file provided - ${e}`,
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (credentialFileType === "authorized_user") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-key",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: credentialJsonStr,
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/service-account-key");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
Upload
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
interface GmailJsonUploadSectionProps {
|
||||
interface DriveJsonUploadSectionProps {
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
appCredentialData?: { client_id: string };
|
||||
serviceAccountCredentialData?: { service_account_email: string };
|
||||
isAdmin: boolean;
|
||||
onSuccess?: () => void;
|
||||
existingAuthCredential?: boolean;
|
||||
}
|
||||
|
||||
export const GmailJsonUploadSection = ({
|
||||
@@ -256,8 +167,7 @@ export const GmailJsonUploadSection = ({
|
||||
serviceAccountCredentialData,
|
||||
isAdmin,
|
||||
onSuccess,
|
||||
existingAuthCredential,
|
||||
}: GmailJsonUploadSectionProps) => {
|
||||
}: DriveJsonUploadSectionProps) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const router = useRouter();
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
@@ -280,138 +190,156 @@ export const GmailJsonUploadSection = ({
|
||||
}
|
||||
};
|
||||
|
||||
if (!isAdmin) {
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
return (
|
||||
<div>
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Curators are unable to set up the Gmail credentials. To add a Gmail
|
||||
connector, please contact an administrator.
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing service account key with the following <b>Email:</b>
|
||||
<p className="italic mt-1">
|
||||
{localServiceAccountData.service_account_email}
|
||||
</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-key",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate(
|
||||
"/api/manage/admin/connector/gmail/service-account-key"
|
||||
);
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
setLocalServiceAccountData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete service account key - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing app credentials with the following <b>Client ID:</b>
|
||||
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
|
||||
setPopup({
|
||||
message: "Successfully deleted app credentials",
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
setLocalAppCredentialData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!isAdmin) {
|
||||
return (
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Curators are unable to set up the Gmail credentials. To add a Gmail
|
||||
connector, please contact an administrator.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<p className="text-sm mb-3">
|
||||
To connect your Gmail, create credentials (either OAuth App or Service
|
||||
Account), download the JSON file, and upload it below.
|
||||
</p>
|
||||
<div className="mb-4">
|
||||
<div className="mt-2">
|
||||
<p className="text-sm mb-2">
|
||||
Follow the guide{" "}
|
||||
<a
|
||||
className="text-primary hover:text-primary/80 flex items-center gap-1 text-sm"
|
||||
className="text-link"
|
||||
target="_blank"
|
||||
href="https://docs.onyx.app/connectors/gmail#authorization"
|
||||
rel="noreferrer"
|
||||
>
|
||||
<FiLink className="h-3 w-3" />
|
||||
View detailed setup instructions
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{(localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id) && (
|
||||
<div className="mb-4">
|
||||
<div className="relative flex flex-1 items-center">
|
||||
<label
|
||||
className={cn(
|
||||
"flex h-10 items-center justify-center w-full px-4 py-2 border border-dashed rounded-md transition-colors",
|
||||
false
|
||||
? "opacity-70 cursor-not-allowed border-background-400 bg-background-50/30"
|
||||
: "cursor-pointer hover:bg-background-50/30 hover:border-primary dark:hover:border-primary border-background-300 dark:border-background-600"
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center space-x-2">
|
||||
{false ? (
|
||||
<div className="h-4 w-4 border-t-2 border-b-2 border-primary rounded-full animate-spin"></div>
|
||||
) : (
|
||||
<FiFile className="h-4 w-4 text-text-500" />
|
||||
)}
|
||||
<span className="text-sm text-text-500">
|
||||
{truncateString(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id ||
|
||||
"",
|
||||
50
|
||||
)}
|
||||
</span>
|
||||
</div>
|
||||
</label>
|
||||
</div>
|
||||
{isAdmin && !existingAuthCredential && (
|
||||
<div className="mt-2">
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
const endpoint =
|
||||
localServiceAccountData?.service_account_email
|
||||
? "/api/manage/admin/connector/gmail/service-account-key"
|
||||
: "/api/manage/admin/connector/gmail/app-credential";
|
||||
|
||||
const response = await fetch(endpoint, {
|
||||
method: "DELETE",
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
mutate(endpoint);
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
|
||||
|
||||
// Add additional mutations to refresh all credential-related endpoints
|
||||
mutate("/api/manage/admin/connector/gmail/credentials");
|
||||
mutate(
|
||||
"/api/manage/admin/connector/gmail/public-credential"
|
||||
);
|
||||
mutate(
|
||||
"/api/manage/admin/connector/gmail/service-account-credential"
|
||||
);
|
||||
|
||||
setPopup({
|
||||
message: `Successfully deleted ${
|
||||
localServiceAccountData
|
||||
? "service account key"
|
||||
: "app credentials"
|
||||
}`,
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
if (localServiceAccountData) {
|
||||
setLocalServiceAccountData(undefined);
|
||||
} else {
|
||||
setLocalAppCredentialData(undefined);
|
||||
}
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete Credentials
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!(
|
||||
localServiceAccountData?.service_account_email ||
|
||||
localAppCredentialData?.client_id
|
||||
) && (
|
||||
<GmailCredentialUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
)}
|
||||
here
|
||||
</a>{" "}
|
||||
to either (1) setup a Google OAuth App in your company workspace or (2)
|
||||
create a Service Account.
|
||||
<br />
|
||||
<br />
|
||||
Download the credentials JSON if choosing option (1) or the Service
|
||||
Account key JSON if choosing option (2), and upload it here.
|
||||
</p>
|
||||
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface GmailCredentialSectionProps {
|
||||
interface DriveCredentialSectionProps {
|
||||
gmailPublicCredential?: Credential<GmailCredentialJson>;
|
||||
gmailServiceAccountCredential?: Credential<GmailServiceAccountCredentialJson>;
|
||||
serviceAccountKeyData?: { service_account_email: string };
|
||||
@@ -459,7 +387,7 @@ export const GmailAuthSection = ({
|
||||
refreshCredentials,
|
||||
connectorExists,
|
||||
user,
|
||||
}: GmailCredentialSectionProps) => {
|
||||
}: DriveCredentialSectionProps) => {
|
||||
const router = useRouter();
|
||||
const [isAuthenticating, setIsAuthenticating] = useState(false);
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
@@ -492,141 +420,104 @@ export const GmailAuthSection = ({
|
||||
localGmailPublicCredential || localGmailServiceAccountCredential;
|
||||
if (existingCredential) {
|
||||
return (
|
||||
<div>
|
||||
<div className="mt-4">
|
||||
<div className="py-3 px-4 bg-blue-50/30 dark:bg-blue-900/5 rounded mb-4 flex items-start">
|
||||
<FiCheck className="text-blue-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<div className="flex-1">
|
||||
<span className="font-medium block">Authentication Complete</span>
|
||||
<p className="text-sm mt-1 text-text-500 dark:text-text-400 break-words">
|
||||
Your Gmail credentials have been successfully uploaded and
|
||||
authenticated.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<Button
|
||||
variant="destructive"
|
||||
type="button"
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorExists,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// If no credentials are uploaded, show message to complete step 1 first
|
||||
if (
|
||||
!localServiceAccountData?.service_account_email &&
|
||||
!localAppCredentialData?.client_id
|
||||
) {
|
||||
return (
|
||||
<div>
|
||||
<SectionHeader>Gmail Authentication</SectionHeader>
|
||||
<div className="mt-4">
|
||||
<div className="flex items-start py-3 px-4 bg-yellow-50/30 dark:bg-yellow-900/5 rounded">
|
||||
<FiAlertTriangle className="text-yellow-500 h-5 w-5 mr-2 mt-0.5 flex-shrink-0" />
|
||||
<p className="text-sm">
|
||||
Please complete Step 1 by uploading either OAuth credentials or a
|
||||
Service Account key before proceeding with authentication.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<>
|
||||
<p className="mb-2 text-sm">
|
||||
<i>Uploaded and authenticated credential already exists!</i>
|
||||
</p>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorExists,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
return (
|
||||
<div>
|
||||
<div className="mt-4">
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_primary_admin: user?.email || "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_primary_admin: Yup.string()
|
||||
.email("Must be a valid email")
|
||||
.required("Required"),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
google_primary_admin: values.google_primary_admin,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
refreshCredentials();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_primary_admin: user?.email || "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_primary_admin: Yup.string()
|
||||
.email("Must be a valid email")
|
||||
.required("Required"),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-credential",
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
google_primary_admin: values.google_primary_admin,
|
||||
}),
|
||||
}
|
||||
} catch (error) {
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${error}`,
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
refreshCredentials();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
} finally {
|
||||
formikHelpers.setSubmitting(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_primary_admin"
|
||||
label="Primary Admin Email:"
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
{isSubmitting ? "Creating..." : "Create Credential"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${error}`,
|
||||
type: "error",
|
||||
});
|
||||
} finally {
|
||||
formikHelpers.setSubmitting(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_primary_admin"
|
||||
label="Primary Admin Email:"
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
Create Credential
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
return (
|
||||
<div>
|
||||
<div className="bg-background-50/30 dark:bg-background-900/20 rounded mb-4">
|
||||
<p className="text-sm">
|
||||
Next, you need to authenticate with Gmail via OAuth. This gives us
|
||||
read access to the emails you have access to in your Gmail account.
|
||||
</p>
|
||||
</div>
|
||||
<div className="text-sm mb-4">
|
||||
<p className="mb-2">
|
||||
Next, you must provide credentials via OAuth. This gives us read
|
||||
access to the emails you have access to in your Gmail account.
|
||||
</p>
|
||||
<Button
|
||||
disabled={isAuthenticating}
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
@@ -654,6 +545,7 @@ export const GmailAuthSection = ({
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
}}
|
||||
disabled={isAuthenticating}
|
||||
>
|
||||
{isAuthenticating ? "Authenticating..." : "Authenticate with Gmail"}
|
||||
</Button>
|
||||
@@ -661,6 +553,11 @@ export const GmailAuthSection = ({
|
||||
);
|
||||
}
|
||||
|
||||
// This code path should not be reached with the new conditions above
|
||||
return null;
|
||||
// case where no keys have been uploaded in step 1
|
||||
return (
|
||||
<p className="text-sm">
|
||||
Please upload either a OAuth Client Credential JSON or a Gmail Service
|
||||
Account Key JSON in Step 1 before moving onto Step 2.
|
||||
</p>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -173,9 +173,6 @@ export const GmailMain = () => {
|
||||
serviceAccountCredentialData={serviceAccountKeyData}
|
||||
isAdmin={isAdmin}
|
||||
onSuccess={handleRefresh}
|
||||
existingAuthCredential={Boolean(
|
||||
gmailPublicUploadedCredential || gmailServiceAccountCredential
|
||||
)}
|
||||
/>
|
||||
|
||||
{isAdmin && hasUploadedCredentials && (
|
||||
|
||||
@@ -114,8 +114,8 @@ function Main() {
|
||||
<ul className="list-disc mt-2 ml-4 mb-2">
|
||||
<li>
|
||||
<Text>
|
||||
Set a global rate limit to control your team's overall token
|
||||
spend.
|
||||
Set a global rate limit to control your organization's overall
|
||||
token spend.
|
||||
</Text>
|
||||
</li>
|
||||
{isPaidEnterpriseFeaturesEnabled && (
|
||||
|
||||
@@ -2,7 +2,14 @@
|
||||
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Formik, Form, Field, ErrorMessage, FieldArray } from "formik";
|
||||
import {
|
||||
Formik,
|
||||
Form,
|
||||
Field,
|
||||
ErrorMessage,
|
||||
FieldArray,
|
||||
ArrayHelpers,
|
||||
} from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
@@ -42,7 +49,7 @@ function prettifyDefinition(definition: any) {
|
||||
return JSON.stringify(definition, null, 2);
|
||||
}
|
||||
|
||||
function ActionForm({
|
||||
function ToolForm({
|
||||
existingTool,
|
||||
values,
|
||||
setFieldValue,
|
||||
@@ -111,7 +118,7 @@ function ActionForm({
|
||||
<TextFormField
|
||||
name="definition"
|
||||
label="Definition"
|
||||
subtext="Specify an OpenAPI schema that defines the APIs you want to make available as part of this action."
|
||||
subtext="Specify an OpenAPI schema that defines the APIs you want to make available as part of this tool."
|
||||
placeholder="Enter your OpenAPI schema here"
|
||||
isTextArea={true}
|
||||
defaultHeight="h-96"
|
||||
@@ -178,7 +185,7 @@ function ActionForm({
|
||||
clipRule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
Learn more about actions in our documentation
|
||||
Learn more about tool calling in our documentation
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
@@ -222,7 +229,7 @@ function ActionForm({
|
||||
Custom Headers
|
||||
</h3>
|
||||
<p className="text-sm mb-6 text-text-600 italic">
|
||||
Specify custom headers for each request to this action's API.
|
||||
Specify custom headers for each request to this tool's API.
|
||||
</p>
|
||||
<FieldArray
|
||||
name="customHeaders"
|
||||
@@ -353,7 +360,7 @@ function ActionForm({
|
||||
type="submit"
|
||||
disabled={isSubmitting || !!definitionError}
|
||||
>
|
||||
{existingTool ? "Update Action" : "Create Action"}
|
||||
{existingTool ? "Update Tool" : "Create Tool"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
@@ -379,7 +386,7 @@ const ToolSchema = Yup.object().shape({
|
||||
passthrough_auth: Yup.boolean().default(false),
|
||||
});
|
||||
|
||||
export function ActionEditor({ tool }: { tool?: ToolSnapshot }) {
|
||||
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
|
||||
const router = useRouter();
|
||||
const { popup, setPopup } = usePopup();
|
||||
const [definitionError, setDefinitionError] = useState<string | null>(null);
|
||||
@@ -425,7 +432,7 @@ export function ActionEditor({ tool }: { tool?: ToolSnapshot }) {
|
||||
try {
|
||||
definition = parseJsonWithTrailingCommas(values.definition);
|
||||
} catch (error) {
|
||||
setDefinitionError("Invalid JSON in action definition");
|
||||
setDefinitionError("Invalid JSON in tool definition");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -446,17 +453,17 @@ export function ActionEditor({ tool }: { tool?: ToolSnapshot }) {
|
||||
}
|
||||
if (response.error) {
|
||||
setPopup({
|
||||
message: "Failed to create action - " + response.error,
|
||||
message: "Failed to create tool - " + response.error,
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
router.push(`/admin/actions?u=${Date.now()}`);
|
||||
router.push(`/admin/tools?u=${Date.now()}`);
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, values, setFieldValue }) => {
|
||||
return (
|
||||
<ActionForm
|
||||
<ToolForm
|
||||
existingTool={tool}
|
||||
values={values}
|
||||
setFieldValue={setFieldValue}
|
||||
@@ -15,7 +15,7 @@ import { TrashIcon } from "@/components/icons/icons";
|
||||
import { deleteCustomTool } from "@/lib/tools/edit";
|
||||
import { TableHeader } from "@/components/ui/table";
|
||||
|
||||
export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
|
||||
export function ToolsTable({ tools }: { tools: ToolSnapshot[] }) {
|
||||
const router = useRouter();
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
@@ -2,7 +2,7 @@ import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import Text from "@/components/ui/text";
|
||||
import Title from "@/components/ui/title";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { ActionEditor } from "@/app/admin/actions/ActionEditor";
|
||||
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
|
||||
import { fetchToolByIdSS } from "@/lib/tools/fetchTools";
|
||||
import { DeleteToolButton } from "./DeleteToolButton";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
@@ -31,7 +31,7 @@ export default async function Page(props: {
|
||||
<div>
|
||||
<div>
|
||||
<CardSection>
|
||||
<ActionEditor tool={tool} />
|
||||
<ToolEditor tool={tool} />
|
||||
</CardSection>
|
||||
|
||||
<Title className="mt-12">Delete Tool</Title>
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { ActionEditor } from "@/app/admin/actions/ActionEditor";
|
||||
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
|
||||
import { BackButton } from "@/components/BackButton";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { ToolIcon } from "@/components/icons/icons";
|
||||
@@ -12,12 +12,12 @@ export default function NewToolPage() {
|
||||
<BackButton />
|
||||
|
||||
<AdminPageTitle
|
||||
title="Create Action"
|
||||
title="Create Tool"
|
||||
icon={<ToolIcon size={32} className="my-auto" />}
|
||||
/>
|
||||
|
||||
<CardSection>
|
||||
<ActionEditor />
|
||||
<ToolEditor />
|
||||
</CardSection>
|
||||
</div>
|
||||
);
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ActionsTable } from "./ActionTable";
|
||||
import { ToolsTable } from "./ToolsTable";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { FiPlusSquare } from "react-icons/fi";
|
||||
import Link from "next/link";
|
||||
@@ -29,23 +29,23 @@ export default async function Page() {
|
||||
<div className="mx-auto container">
|
||||
<AdminPageTitle
|
||||
icon={<ToolIcon size={32} className="my-auto" />}
|
||||
title="Actions"
|
||||
title="Tools"
|
||||
/>
|
||||
|
||||
<Text className="mb-2">
|
||||
Actions allow assistants to retrieve information or take actions.
|
||||
Tools allow assistants to retrieve information or take actions.
|
||||
</Text>
|
||||
|
||||
<div>
|
||||
<Separator />
|
||||
|
||||
<Title>Create an Action</Title>
|
||||
<CreateButton href="/admin/actions/new" text="New Action" />
|
||||
<Title>Create a Tool</Title>
|
||||
<CreateButton href="/admin/tools/new" text="New Tool" />
|
||||
|
||||
<Separator />
|
||||
|
||||
<Title>Existing Actions</Title>
|
||||
<ActionsTable tools={tools} />
|
||||
<Title>Existing Tools</Title>
|
||||
<ToolsTable tools={tools} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@@ -21,8 +21,7 @@ import { InvitedUserSnapshot } from "@/lib/types";
|
||||
import { SearchBar } from "@/components/search/SearchBar";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import PendingUsersTable from "@/components/admin/users/PendingUsersTable";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
|
||||
const UsersTables = ({
|
||||
q,
|
||||
setPopup,
|
||||
@@ -45,15 +44,6 @@ const UsersTables = ({
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const {
|
||||
data: pendingUsers,
|
||||
error: pendingUsersError,
|
||||
isLoading: pendingUsersLoading,
|
||||
mutate: pendingUsersMutate,
|
||||
} = useSWR<InvitedUserSnapshot[]>(
|
||||
NEXT_PUBLIC_CLOUD_ENABLED ? "/api/tenants/users/pending" : null,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
// Show loading animation only during the initial data fetch
|
||||
if (!validDomains) {
|
||||
return <ThreeDotsLoader />;
|
||||
@@ -73,9 +63,6 @@ const UsersTables = ({
|
||||
<TabsList>
|
||||
<TabsTrigger value="current">Current Users</TabsTrigger>
|
||||
<TabsTrigger value="invited">Invited Users</TabsTrigger>
|
||||
{NEXT_PUBLIC_CLOUD_ENABLED && (
|
||||
<TabsTrigger value="pending">Pending Users</TabsTrigger>
|
||||
)}
|
||||
</TabsList>
|
||||
|
||||
<TabsContent value="current">
|
||||
@@ -110,25 +97,6 @@ const UsersTables = ({
|
||||
</CardContent>
|
||||
</Card>
|
||||
</TabsContent>
|
||||
{NEXT_PUBLIC_CLOUD_ENABLED && (
|
||||
<TabsContent value="pending">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Pending Users</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<PendingUsersTable
|
||||
users={pendingUsers || []}
|
||||
setPopup={setPopup}
|
||||
mutate={pendingUsersMutate}
|
||||
error={pendingUsersError}
|
||||
isLoading={pendingUsersLoading}
|
||||
q={q}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</TabsContent>
|
||||
)}
|
||||
</Tabs>
|
||||
);
|
||||
};
|
||||
@@ -222,7 +190,7 @@ const AddUserButton = ({
|
||||
entityName="your Access Logic"
|
||||
onClose={() => setShowConfirmation(false)}
|
||||
onSubmit={handleConfirmFirstInvite}
|
||||
additionalDetails="After inviting the first user, only invited users will be able to join this platform. This is a security measure to control access to your team."
|
||||
additionalDetails="After inviting the first user, only invited users will be able to join this platform. This is a security measure to control access to your instance."
|
||||
actionButtonText="Continue"
|
||||
variant="action"
|
||||
/>
|
||||
|
||||
@@ -18,8 +18,8 @@ const Page = () => {
|
||||
need to either:
|
||||
</p>
|
||||
<ul className="list-disc text-left text-text-600 w-full pl-6 mx-auto">
|
||||
<li>Be invited to an existing Onyx team</li>
|
||||
<li>Create a new Onyx team</li>
|
||||
<li>Be invited to an existing Onyx organization</li>
|
||||
<li>Create a new Onyx organization</li>
|
||||
</ul>
|
||||
<div className="flex justify-center">
|
||||
<Link
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
import { HealthCheckBanner } from "@/components/health/healthcheck";
|
||||
import { User } from "@/lib/types";
|
||||
import {
|
||||
getCurrentUserSS,
|
||||
getAuthTypeMetadataSS,
|
||||
AuthTypeMetadata,
|
||||
getAuthUrlSS,
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
import { EmailPasswordForm } from "../login/EmailPasswordForm";
|
||||
import Text from "@/components/ui/text";
|
||||
import Link from "next/link";
|
||||
import { SignInButton } from "../login/SignInButton";
|
||||
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
|
||||
import AuthErrorDisplay from "@/components/auth/AuthErrorDisplay";
|
||||
|
||||
const Page = async (props: {
|
||||
searchParams?: Promise<{ [key: string]: string | string[] | undefined }>;
|
||||
}) => {
|
||||
const searchParams = await props.searchParams;
|
||||
const nextUrl = Array.isArray(searchParams?.next)
|
||||
? searchParams?.next[0]
|
||||
: searchParams?.next || null;
|
||||
|
||||
const defaultEmail = Array.isArray(searchParams?.email)
|
||||
? searchParams?.email[0]
|
||||
: searchParams?.email || null;
|
||||
|
||||
const teamName = Array.isArray(searchParams?.team)
|
||||
? searchParams?.team[0]
|
||||
: searchParams?.team || "your team";
|
||||
|
||||
// catch cases where the backend is completely unreachable here
|
||||
// without try / catch, will just raise an exception and the page
|
||||
// will not render
|
||||
let authTypeMetadata: AuthTypeMetadata | null = null;
|
||||
let currentUser: User | null = null;
|
||||
try {
|
||||
[authTypeMetadata, currentUser] = await Promise.all([
|
||||
getAuthTypeMetadataSS(),
|
||||
getCurrentUserSS(),
|
||||
]);
|
||||
} catch (e) {
|
||||
console.log(`Some fetch failed for the login page - ${e}`);
|
||||
}
|
||||
|
||||
// simply take the user to the home page if Auth is disabled
|
||||
if (authTypeMetadata?.authType === "disabled") {
|
||||
return redirect("/chat");
|
||||
}
|
||||
|
||||
// if user is already logged in, take them to the main app page
|
||||
if (currentUser && currentUser.is_active && !currentUser.is_anonymous_user) {
|
||||
if (!authTypeMetadata?.requiresVerification || currentUser.is_verified) {
|
||||
return redirect("/chat");
|
||||
}
|
||||
return redirect("/auth/waiting-on-verification");
|
||||
}
|
||||
const cloud = authTypeMetadata?.authType === "cloud";
|
||||
|
||||
// only enable this page if basic login is enabled
|
||||
if (authTypeMetadata?.authType !== "basic" && !cloud) {
|
||||
return redirect("/chat");
|
||||
}
|
||||
|
||||
let authUrl: string | null = null;
|
||||
if (cloud && authTypeMetadata) {
|
||||
authUrl = await getAuthUrlSS(authTypeMetadata.authType, null);
|
||||
}
|
||||
const emailDomain = defaultEmail?.split("@")[1];
|
||||
|
||||
return (
|
||||
<AuthFlowContainer authState="join">
|
||||
<HealthCheckBanner />
|
||||
<AuthErrorDisplay searchParams={searchParams} />
|
||||
|
||||
<>
|
||||
<div className="absolute top-10x w-full"></div>
|
||||
<div className="flex w-full flex-col justify-center">
|
||||
<h2 className="text-center text-xl text-strong font-bold">
|
||||
Re-authenticate to join team
|
||||
</h2>
|
||||
|
||||
{cloud && authUrl && (
|
||||
<div className="w-full justify-center">
|
||||
<SignInButton authorizeUrl={authUrl} authType="cloud" />
|
||||
<div className="flex items-center w-full my-4">
|
||||
<div className="flex-grow border-t border-background-300"></div>
|
||||
<span className="px-4 text-text-500">or</span>
|
||||
<div className="flex-grow border-t border-background-300"></div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<EmailPasswordForm
|
||||
isSignup
|
||||
isJoin
|
||||
shouldVerify={authTypeMetadata?.requiresVerification}
|
||||
nextUrl={nextUrl}
|
||||
defaultEmail={defaultEmail}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
</AuthFlowContainer>
|
||||
);
|
||||
};
|
||||
|
||||
export default Page;
|
||||
@@ -13,7 +13,6 @@ import { set } from "lodash";
|
||||
import { NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED } from "@/lib/constants";
|
||||
import Link from "next/link";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function EmailPasswordForm({
|
||||
isSignup = false,
|
||||
@@ -21,18 +20,15 @@ export function EmailPasswordForm({
|
||||
referralSource,
|
||||
nextUrl,
|
||||
defaultEmail,
|
||||
isJoin = false,
|
||||
}: {
|
||||
isSignup?: boolean;
|
||||
shouldVerify?: boolean;
|
||||
referralSource?: string;
|
||||
nextUrl?: string | null;
|
||||
defaultEmail?: string | null;
|
||||
isJoin?: boolean;
|
||||
}) {
|
||||
const { user } = useUser();
|
||||
const { popup, setPopup } = usePopup();
|
||||
const router = useRouter();
|
||||
const [isWorking, setIsWorking] = useState(false);
|
||||
return (
|
||||
<>
|
||||
@@ -83,11 +79,6 @@ export function EmailPasswordForm({
|
||||
});
|
||||
setIsWorking(false);
|
||||
return;
|
||||
} else {
|
||||
setPopup({
|
||||
type: "success",
|
||||
message: "Account created successfully. Please log in.",
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,9 +92,7 @@ export function EmailPasswordForm({
|
||||
window.location.href = "/auth/waiting-on-verification";
|
||||
} else {
|
||||
// See above comment
|
||||
window.location.href = nextUrl
|
||||
? encodeURI(nextUrl)
|
||||
: `/chat${isSignup && !isJoin ? "?new_team=true" : ""}`;
|
||||
window.location.href = nextUrl ? encodeURI(nextUrl) : "/";
|
||||
}
|
||||
} else {
|
||||
setIsWorking(false);
|
||||
@@ -146,12 +135,11 @@ export function EmailPasswordForm({
|
||||
/>
|
||||
|
||||
<Button
|
||||
variant="agent"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className="mx-auto !py-4 w-full"
|
||||
>
|
||||
{isJoin ? "Join" : isSignup ? "Sign Up" : "Log In"}
|
||||
{isSignup ? "Sign Up" : "Log In"}
|
||||
</Button>
|
||||
{user?.is_anonymous_user && (
|
||||
<Link
|
||||
|
||||
@@ -51,16 +51,25 @@ export default function LoginPage({
|
||||
</div>
|
||||
<EmailPasswordForm shouldVerify={true} nextUrl={nextUrl} />
|
||||
|
||||
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
|
||||
<div className="flex mt-4 justify-between">
|
||||
<div className="flex mt-4 justify-between">
|
||||
<Link
|
||||
href={`/auth/signup${
|
||||
searchParams?.next ? `?next=${searchParams.next}` : ""
|
||||
}`}
|
||||
className="text-link font-medium"
|
||||
>
|
||||
Create an account
|
||||
</Link>
|
||||
|
||||
{NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED && (
|
||||
<Link
|
||||
href="/auth/forgot-password"
|
||||
className="text-link font-medium"
|
||||
>
|
||||
Reset Password
|
||||
</Link>
|
||||
</div>
|
||||
)}
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ export function SignInButton({
|
||||
|
||||
return (
|
||||
<a
|
||||
className="mx-auto mb-4 mt-6 py-3 w-full dark:text-neutral-300 text-neutral-600 border border-neutral-300 flex rounded cursor-pointer hover:border-neutral-400 transition-colors"
|
||||
className="mx-auto mb-4 mt-6 py-3 w-full text-neutral-100 bg-indigo-500 flex rounded cursor-pointer hover:bg-indigo-800"
|
||||
href={finalAuthorizeUrl}
|
||||
>
|
||||
{button}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { useContext, useState, useRef, useLayoutEffect } from "react";
|
||||
import { ChevronDownIcon } from "@/components/icons/icons";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { MinimalMarkdown } from "@/components/chat/MinimalMarkdown";
|
||||
|
||||
export function ChatBanner() {
|
||||
const settings = useContext(SettingsContext);
|
||||
|
||||
@@ -109,6 +109,7 @@ import {
|
||||
} from "@/components/resizable/constants";
|
||||
import FixedLogo from "../../components/logo/FixedLogo";
|
||||
|
||||
import { MinimalMarkdown } from "@/components/chat/MinimalMarkdown";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
|
||||
import {
|
||||
@@ -137,7 +138,6 @@ import { useSidebarShortcut } from "@/lib/browserUtilities";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
|
||||
import { ErrorBanner } from "./message/Resubmit";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@@ -215,7 +215,11 @@ export function ChatPage({
|
||||
const isInitialLoad = useRef(true);
|
||||
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
|
||||
|
||||
const { assistants: availableAssistants, pinnedAssistants } = useAssistants();
|
||||
const {
|
||||
assistants: availableAssistants,
|
||||
finalAssistants,
|
||||
pinnedAssistants,
|
||||
} = useAssistants();
|
||||
|
||||
const [showApiKeyModal, setShowApiKeyModal] = useState(
|
||||
!shouldShowWelcomeModal
|
||||
@@ -225,7 +229,7 @@ export function ChatPage({
|
||||
const slackChatId = searchParams.get("slackChatId");
|
||||
const existingChatIdRaw = searchParams.get("chatId");
|
||||
|
||||
const [showHistorySidebar, setShowHistorySidebar] = useState(false);
|
||||
const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open
|
||||
|
||||
const existingChatSessionId = existingChatIdRaw ? existingChatIdRaw : null;
|
||||
|
||||
@@ -2447,7 +2451,7 @@ export function ChatPage({
|
||||
h-full
|
||||
${sidebarVisible ? "w-[200px]" : "w-[0px]"}
|
||||
`}
|
||||
/>
|
||||
></div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -6,7 +6,6 @@ import { Button } from "@/components/ui/button";
|
||||
import { useContext, useEffect, useState } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import { transformLinkUri } from "@/lib/utils";
|
||||
|
||||
const ALL_USERS_INITIAL_POPUP_FLOW_COMPLETED =
|
||||
"allUsersInitialPopupFlowCompleted";
|
||||
@@ -45,26 +44,23 @@ export function ChatPopup() {
|
||||
return (
|
||||
<Modal width="w-3/6 xl:w-[700px]" title={popupTitle}>
|
||||
<>
|
||||
<div className="overflow-y-auto max-h-[90vh] py-8 px-4 text-left">
|
||||
<ReactMarkdown
|
||||
className="prose text-text-800 dark:text-neutral-100 max-w-full"
|
||||
components={{
|
||||
a: ({ node, ...props }) => (
|
||||
<a
|
||||
{...props}
|
||||
className="text-link hover:text-link-hover"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
/>
|
||||
),
|
||||
p: ({ node, ...props }) => <p {...props} className="text-sm" />,
|
||||
}}
|
||||
remarkPlugins={[remarkGfm]}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{popupContent}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
<ReactMarkdown
|
||||
className="prose text-text-800 dark:text-neutral-100 max-w-full"
|
||||
components={{
|
||||
a: ({ node, ...props }) => (
|
||||
<a
|
||||
{...props}
|
||||
className="text-link hover:text-link-hover"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
/>
|
||||
),
|
||||
p: ({ node, ...props }) => <p {...props} className="text-sm" />,
|
||||
}}
|
||||
remarkPlugins={[remarkGfm]}
|
||||
>
|
||||
{popupContent}
|
||||
</ReactMarkdown>
|
||||
|
||||
{showConsentError && (
|
||||
<p className="text-red-500 text-sm mt-2">
|
||||
|
||||
@@ -53,7 +53,6 @@ import { copyAll, handleCopy } from "./copyingUtils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { RefreshCw } from "lucide-react";
|
||||
import { ErrorBanner, Resubmit } from "./Resubmit";
|
||||
import { transformLinkUri } from "@/lib/utils";
|
||||
|
||||
export const AgenticMessage = ({
|
||||
isStreamingQuestions,
|
||||
@@ -337,7 +336,6 @@ export const AgenticMessage = ({
|
||||
}}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{finalAlternativeContent}
|
||||
</ReactMarkdown>
|
||||
@@ -351,7 +349,6 @@ export const AgenticMessage = ({
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{streamedContent +
|
||||
(!isComplete && !secondLevelGenerating ? " [*]() " : "")}
|
||||
|
||||
@@ -160,9 +160,8 @@ export const MemoizedLink = memo(
|
||||
|
||||
const handleMouseDown = () => {
|
||||
let url = href || rest.children?.toString();
|
||||
|
||||
if (url && !url.includes("://")) {
|
||||
// Only add https:// if the URL doesn't already have a protocol
|
||||
if (url && !url.startsWith("http://") && !url.startsWith("https://")) {
|
||||
// Try to construct a valid URL
|
||||
const httpsUrl = `https://${url}`;
|
||||
try {
|
||||
new URL(httpsUrl);
|
||||
|
||||
@@ -71,7 +71,6 @@ import remarkMath from "remark-math";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
import { copyAll, handleCopy } from "./copyingUtils";
|
||||
import { transformLinkUri } from "@/lib/utils";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -349,7 +348,7 @@ export const AIMessage = ({
|
||||
a: anchorCallback,
|
||||
p: paragraphCallback,
|
||||
b: ({ node, className, children }: any) => {
|
||||
return <span className={className}>{children}</span>;
|
||||
return <span className={className}>||||{children}</span>;
|
||||
},
|
||||
code: ({ node, className, children }: any) => {
|
||||
const codeText = extractCodeText(
|
||||
@@ -382,7 +381,6 @@ export const AIMessage = ({
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{finalContent}
|
||||
</ReactMarkdown>
|
||||
|
||||
@@ -16,15 +16,15 @@ import ReactMarkdown from "react-markdown";
|
||||
import { MemoizedAnchor } from "./MemoizedTextComponents";
|
||||
import { MemoizedParagraph } from "./MemoizedTextComponents";
|
||||
import { extractCodeText, preprocessLaTeX } from "./codeUtils";
|
||||
import remarkGfm from "remark-gfm";
|
||||
|
||||
import remarkMath from "remark-math";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import { CodeBlock } from "./CodeBlock";
|
||||
import { CheckIcon, ChevronDown } from "lucide-react";
|
||||
import { PHASE_MIN_MS, useStreamingMessages } from "./StreamingMessages";
|
||||
import { CirclingArrowIcon } from "@/components/icons/icons";
|
||||
import { handleCopy } from "./copyingUtils";
|
||||
import { transformLinkUri } from "@/lib/utils";
|
||||
|
||||
export const StatusIndicator = ({ status }: { status: ToggleState }) => {
|
||||
return (
|
||||
@@ -301,7 +301,6 @@ const SubQuestionDisplay: React.FC<{
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[rehypeKatex]}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{finalContent}
|
||||
</ReactMarkdown>
|
||||
|
||||
@@ -62,13 +62,19 @@ export function extractCodeText(
|
||||
|
||||
// We must preprocess LaTeX in the LLM output to avoid improper formatting
|
||||
export const preprocessLaTeX = (content: string) => {
|
||||
// 1) Replace block-level LaTeX delimiters \[ \] with $$ $$
|
||||
const blockProcessedContent = content.replace(
|
||||
// 1) Escape dollar signs used outside of LaTeX context
|
||||
const escapedCurrencyContent = content.replace(
|
||||
/\$(\d+(?:\.\d*)?)/g,
|
||||
(_, p1) => `\\$${p1}`
|
||||
);
|
||||
|
||||
// 2) Replace block-level LaTeX delimiters \[ \] with $$ $$
|
||||
const blockProcessedContent = escapedCurrencyContent.replace(
|
||||
/\\\[([\s\S]*?)\\\]/g,
|
||||
(_, equation) => `$$${equation}$$`
|
||||
);
|
||||
|
||||
// 2) Replace inline LaTeX delimiters \( \) with $ $
|
||||
// 3) Replace inline LaTeX delimiters \( \) with $ $
|
||||
const inlineProcessedContent = blockProcessedContent.replace(
|
||||
/\\\(([\s\S]*?)\\\)/g,
|
||||
(_, equation) => `$${equation}$`
|
||||
@@ -76,3 +82,223 @@ export const preprocessLaTeX = (content: string) => {
|
||||
|
||||
return inlineProcessedContent;
|
||||
};
|
||||
|
||||
interface MarkdownSegment {
|
||||
type: "text" | "link" | "code" | "bold" | "italic" | "codeblock";
|
||||
text: string; // The visible/plain text
|
||||
raw: string; // The raw markdown including syntax
|
||||
length: number; // Length of the visible text
|
||||
}
|
||||
|
||||
export function parseMarkdownToSegments(markdown: string): MarkdownSegment[] {
|
||||
if (!markdown) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const segments: MarkdownSegment[] = [];
|
||||
let currentIndex = 0;
|
||||
const maxIterations = markdown.length * 2; // Prevent infinite loops
|
||||
let iterations = 0;
|
||||
|
||||
while (currentIndex < markdown.length && iterations < maxIterations) {
|
||||
iterations++;
|
||||
let matched = false;
|
||||
|
||||
// Check for code blocks first (they take precedence)
|
||||
const codeBlockMatch = markdown
|
||||
.slice(currentIndex)
|
||||
.match(/^```(\w*)\n([\s\S]*?)```/);
|
||||
if (codeBlockMatch && codeBlockMatch[0]) {
|
||||
const [fullMatch, , code] = codeBlockMatch;
|
||||
segments.push({
|
||||
type: "codeblock",
|
||||
text: code || "",
|
||||
raw: fullMatch,
|
||||
length: (code || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for inline code
|
||||
const inlineCodeMatch = markdown.slice(currentIndex).match(/^`([^`]+)`/);
|
||||
if (inlineCodeMatch && inlineCodeMatch[0]) {
|
||||
const [fullMatch, code] = inlineCodeMatch;
|
||||
segments.push({
|
||||
type: "code",
|
||||
text: code || "",
|
||||
raw: fullMatch,
|
||||
length: (code || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for links
|
||||
const linkMatch = markdown
|
||||
.slice(currentIndex)
|
||||
.match(/^\[([^\]]+)\]\(([^)]+)\)/);
|
||||
if (linkMatch && linkMatch[0]) {
|
||||
const [fullMatch, text] = linkMatch;
|
||||
segments.push({
|
||||
type: "link",
|
||||
text: text || "",
|
||||
raw: fullMatch,
|
||||
length: (text || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for bold
|
||||
const boldMatch = markdown
|
||||
.slice(currentIndex)
|
||||
.match(/^(\*\*|__)([^*_\n]*?)\1/);
|
||||
if (boldMatch && boldMatch[0]) {
|
||||
const [fullMatch, , text] = boldMatch;
|
||||
segments.push({
|
||||
type: "bold",
|
||||
text: text || "",
|
||||
raw: fullMatch,
|
||||
length: (text || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for italic
|
||||
const italicMatch = markdown
|
||||
.slice(currentIndex)
|
||||
.match(/^(\*|_)([^*_\n]+?)\1(?!\*|_)/);
|
||||
if (italicMatch && italicMatch[0]) {
|
||||
const [fullMatch, , text] = italicMatch;
|
||||
segments.push({
|
||||
type: "italic",
|
||||
text: text || "",
|
||||
raw: fullMatch,
|
||||
length: (text || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// If no matches were found, handle regular text
|
||||
if (!matched) {
|
||||
let nextSpecialChar = markdown.slice(currentIndex).search(/[`\[*_]/);
|
||||
if (nextSpecialChar === -1) {
|
||||
// No more special characters, add the rest as text
|
||||
const text = markdown.slice(currentIndex);
|
||||
if (text) {
|
||||
segments.push({
|
||||
type: "text",
|
||||
text: text,
|
||||
raw: text,
|
||||
length: text.length,
|
||||
});
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
// Add the text up to the next special character
|
||||
const text = markdown.slice(
|
||||
currentIndex,
|
||||
currentIndex + nextSpecialChar
|
||||
);
|
||||
if (text) {
|
||||
segments.push({
|
||||
type: "text",
|
||||
text: text,
|
||||
raw: text,
|
||||
length: text.length,
|
||||
});
|
||||
}
|
||||
currentIndex += nextSpecialChar;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return segments;
|
||||
}
|
||||
|
||||
export function getMarkdownForSelection(
|
||||
content: string,
|
||||
selectedText: string
|
||||
): string {
|
||||
const segments = parseMarkdownToSegments(content);
|
||||
|
||||
// Build plain text and create mapping to markdown segments
|
||||
let plainText = "";
|
||||
const markdownPieces: string[] = [];
|
||||
let currentPlainIndex = 0;
|
||||
|
||||
segments.forEach((segment) => {
|
||||
plainText += segment.text;
|
||||
markdownPieces.push(segment.raw);
|
||||
currentPlainIndex += segment.length;
|
||||
});
|
||||
|
||||
// Find the selection in the plain text
|
||||
const startIndex = plainText.indexOf(selectedText);
|
||||
if (startIndex === -1) {
|
||||
return selectedText;
|
||||
}
|
||||
|
||||
const endIndex = startIndex + selectedText.length;
|
||||
|
||||
// Find which segments the selection spans
|
||||
let currentIndex = 0;
|
||||
let result = "";
|
||||
let selectionStart = startIndex;
|
||||
let selectionEnd = endIndex;
|
||||
|
||||
segments.forEach((segment) => {
|
||||
const segmentStart = currentIndex;
|
||||
const segmentEnd = segmentStart + segment.length;
|
||||
|
||||
// Check if this segment overlaps with the selection
|
||||
if (segmentEnd > selectionStart && segmentStart < selectionEnd) {
|
||||
// Calculate how much of this segment to include
|
||||
const overlapStart = Math.max(0, selectionStart - segmentStart);
|
||||
const overlapEnd = Math.min(segment.length, selectionEnd - segmentStart);
|
||||
|
||||
if (segment.type === "text") {
|
||||
const textPortion = segment.text.slice(overlapStart, overlapEnd);
|
||||
result += textPortion;
|
||||
} else {
|
||||
// For markdown elements, wrap just the selected portion with the appropriate markdown
|
||||
const selectedPortion = segment.text.slice(overlapStart, overlapEnd);
|
||||
|
||||
switch (segment.type) {
|
||||
case "bold":
|
||||
result += `**${selectedPortion}**`;
|
||||
break;
|
||||
case "italic":
|
||||
result += `*${selectedPortion}*`;
|
||||
break;
|
||||
case "code":
|
||||
result += `\`${selectedPortion}\``;
|
||||
break;
|
||||
case "link":
|
||||
// For links, we need to preserve the URL if it exists in the raw markdown
|
||||
const urlMatch = segment.raw.match(/\]\((.*?)\)/);
|
||||
const url = urlMatch ? urlMatch[1] : "";
|
||||
result += `[${selectedPortion}](${url})`;
|
||||
break;
|
||||
case "codeblock":
|
||||
result += `\`\`\`\n${selectedPortion}\n\`\`\``;
|
||||
break;
|
||||
default:
|
||||
result += selectedPortion;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentIndex += segment.length;
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -117,8 +117,9 @@ export function ShareChatSessionModal({
|
||||
{shareLink ? (
|
||||
<div>
|
||||
<Text>
|
||||
This chat session is currently shared. Anyone in your team can
|
||||
view the message history using the following link:
|
||||
This chat session is currently shared. Anyone in your
|
||||
organization can view the message history using the following
|
||||
link:
|
||||
</Text>
|
||||
|
||||
<div className="flex mt-2">
|
||||
@@ -159,7 +160,7 @@ export function ShareChatSessionModal({
|
||||
<div>
|
||||
<Callout type="warning" title="Warning" className="mb-4">
|
||||
Please make sure that all content in this chat is safe to
|
||||
share with the whole team.
|
||||
share with the whole organization.
|
||||
</Callout>
|
||||
<div className="flex w-full justify-between">
|
||||
<Button
|
||||
|
||||
@@ -12,7 +12,7 @@ function Main() {
|
||||
<div className="mt-4">
|
||||
<Callout type="danger" title="Custom Analytics is not enabled.">
|
||||
To set up custom analytics scripts, please work with the team who
|
||||
setup Onyx in your team to set the{" "}
|
||||
setup Onyx in your organization to set the{" "}
|
||||
<i>CUSTOM_ANALYTICS_SECRET_KEY</i> environment variable.
|
||||
</Callout>
|
||||
</div>
|
||||
|
||||
@@ -140,7 +140,7 @@ export function WhitelabelingForm() {
|
||||
<TextFormField
|
||||
label="Application Name"
|
||||
name="application_name"
|
||||
subtext={`The custom name you are giving Onyx for your team. This will replace 'Onyx' everywhere in the UI.`}
|
||||
subtext={`The custom name you are giving Onyx for your organization. This will replace 'Onyx' everywhere in the UI.`}
|
||||
placeholder="Custom name which will replace 'Onyx'"
|
||||
disabled={isSubmitting}
|
||||
/>
|
||||
|
||||
@@ -202,10 +202,10 @@ export function ClientLayout({
|
||||
className="text-text-700"
|
||||
size={18}
|
||||
/>
|
||||
<div className="ml-1">Actions</div>
|
||||
<div className="ml-1">Tools</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/actions",
|
||||
link: "/admin/tools",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
|
||||
@@ -2,8 +2,19 @@
|
||||
"use client";
|
||||
import React, { useContext } from "react";
|
||||
import Link from "next/link";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
|
||||
import { HeaderTitle } from "@/components/header/HeaderTitle";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { WarningCircle, WarningDiamond } from "@phosphor-icons/react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { CgArrowsExpandUpLeft } from "react-icons/cg";
|
||||
import LogoWithText from "@/components/header/LogoWithText";
|
||||
import { LogoComponent } from "@/components/logo/FixedLogo";
|
||||
|
||||
interface Item {
|
||||
|
||||
@@ -33,8 +33,6 @@ import Link from "next/link";
|
||||
import { CheckboxField } from "@/components/ui/checkbox";
|
||||
import { CheckedState } from "@radix-ui/react-checkbox";
|
||||
|
||||
import { transformLinkUri } from "@/lib/utils";
|
||||
|
||||
export function SectionHeader({
|
||||
children,
|
||||
}: {
|
||||
@@ -434,7 +432,6 @@ export const MarkdownFormField = ({
|
||||
<ReactMarkdown
|
||||
className="prose dark:prose-invert"
|
||||
remarkPlugins={[remarkGfm]}
|
||||
urlTransform={transformLinkUri}
|
||||
>
|
||||
{field.value}
|
||||
</ReactMarkdown>
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
import { useState } from "react";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import {
|
||||
Table,
|
||||
TableHead,
|
||||
TableRow,
|
||||
TableBody,
|
||||
TableCell,
|
||||
} from "@/components/ui/table";
|
||||
import CenteredPageSelector from "./CenteredPageSelector";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { InvitedUserSnapshot } from "@/lib/types";
|
||||
import { TableHeader } from "@/components/ui/table";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { FetchError } from "@/lib/fetcher";
|
||||
import { CheckIcon } from "lucide-react";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
|
||||
const USERS_PER_PAGE = 10;
|
||||
|
||||
interface Props {
|
||||
users: InvitedUserSnapshot[];
|
||||
setPopup: (spec: PopupSpec) => void;
|
||||
mutate: () => void;
|
||||
error: FetchError | null;
|
||||
isLoading: boolean;
|
||||
q: string;
|
||||
}
|
||||
|
||||
const PendingUsersTable = ({
|
||||
users,
|
||||
setPopup,
|
||||
mutate,
|
||||
error,
|
||||
isLoading,
|
||||
q,
|
||||
}: Props) => {
|
||||
const [currentPageNum, setCurrentPageNum] = useState<number>(1);
|
||||
const [userToApprove, setUserToApprove] = useState<string | null>(null);
|
||||
|
||||
if (!users.length)
|
||||
return <p>Users that have requested to join will show up here</p>;
|
||||
|
||||
const totalPages = Math.ceil(users.length / USERS_PER_PAGE);
|
||||
|
||||
// Filter users based on the search query
|
||||
const filteredUsers = q
|
||||
? users.filter((user) => user.email.includes(q))
|
||||
: users;
|
||||
|
||||
// Get the current page of users
|
||||
const currentPageOfUsers = filteredUsers.slice(
|
||||
(currentPageNum - 1) * USERS_PER_PAGE,
|
||||
currentPageNum * USERS_PER_PAGE
|
||||
);
|
||||
|
||||
if (isLoading) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Error loading pending users"
|
||||
errorMsg={error?.info?.detail}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const handleAcceptRequest = async (email: string) => {
|
||||
try {
|
||||
await fetch("/api/tenants/users/invite/approve", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ email }),
|
||||
});
|
||||
mutate();
|
||||
setUserToApprove(null);
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: "Failed to approve user request",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{userToApprove && (
|
||||
<ConfirmEntityModal
|
||||
entityType="Join Request"
|
||||
entityName={userToApprove}
|
||||
onClose={() => setUserToApprove(null)}
|
||||
onSubmit={() => handleAcceptRequest(userToApprove)}
|
||||
actionButtonText="Approve"
|
||||
actionText="approve the join request of"
|
||||
additionalDetails={`${userToApprove} has requested to join the team. Approving will add them as a user in this team.`}
|
||||
variant="action"
|
||||
accent
|
||||
removeConfirmationText
|
||||
/>
|
||||
)}
|
||||
<Table className="overflow-visible">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Email</TableHead>
|
||||
<TableHead>
|
||||
<div className="flex justify-end">Actions</div>
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{currentPageOfUsers.length ? (
|
||||
currentPageOfUsers.map((user) => (
|
||||
<TableRow key={user.email}>
|
||||
<TableCell>{user.email}</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setUserToApprove(user.email)}
|
||||
>
|
||||
<CheckIcon className="h-4 w-4" />
|
||||
Accept Join Request
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))
|
||||
) : (
|
||||
<TableRow>
|
||||
<TableCell colSpan={2} className="h-24 text-center">
|
||||
{`No pending users found matching "${q}"`}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
{totalPages > 1 ? (
|
||||
<CenteredPageSelector
|
||||
currentPage={currentPageNum}
|
||||
totalPages={totalPages}
|
||||
onPageChange={setCurrentPageNum}
|
||||
/>
|
||||
) : null}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default PendingUsersTable;
|
||||
@@ -22,19 +22,19 @@ export const LeaveOrganizationButton = ({
|
||||
}) => {
|
||||
const router = useRouter();
|
||||
const { trigger, isMutating } = useSWRMutation(
|
||||
"/api/tenants/leave-team",
|
||||
"/api/tenants/leave-organization",
|
||||
userMutationFetcher,
|
||||
{
|
||||
onSuccess: () => {
|
||||
mutate();
|
||||
setPopup({
|
||||
message: "Successfully left the team!",
|
||||
message: "Successfully left the organization!",
|
||||
type: "success",
|
||||
});
|
||||
},
|
||||
onError: (errorMsg) =>
|
||||
setPopup({
|
||||
message: `Unable to leave team - ${errorMsg}`,
|
||||
message: `Unable to leave organization - ${errorMsg}`,
|
||||
type: "error",
|
||||
}),
|
||||
}
|
||||
@@ -53,11 +53,11 @@ export const LeaveOrganizationButton = ({
|
||||
<ConfirmEntityModal
|
||||
variant="action"
|
||||
actionButtonText="Leave"
|
||||
entityType="team"
|
||||
entityName="your team"
|
||||
entityType="organization"
|
||||
entityName="your organization"
|
||||
onClose={() => setShowLeaveModal(false)}
|
||||
onSubmit={handleLeaveOrganization}
|
||||
additionalDetails="You will lose access to all team data and resources."
|
||||
additionalDetails="You will lose access to all organization data and resources."
|
||||
/>
|
||||
)}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useEffect } from "react";
|
||||
import { usePopup } from "../admin/connectors/Popup";
|
||||
|
||||
const ERROR_MESSAGES = {
|
||||
Anonymous: "Your team does not have anonymous access enabled.",
|
||||
Anonymous: "Your organization does not have anonymous access enabled.",
|
||||
};
|
||||
|
||||
export default function AuthErrorDisplay({
|
||||
|
||||
@@ -6,7 +6,7 @@ export default function AuthFlowContainer({
|
||||
authState,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
authState?: "signup" | "login" | "join";
|
||||
authState?: "signup" | "login";
|
||||
}) {
|
||||
return (
|
||||
<div className="p-4 flex flex-col items-center justify-center min-h-screen bg-background">
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user