Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-19 00:35:46 +00:00)

Compare commits: improved_a...github_lis (17 commits)
Commits in this comparison (SHA1): 9087320a06, b0af1458c0, bb67a7a122, e239dc31c1, 027128502c, a7a374dc81, facc8cc2fa, 2c0af0a0ca, bfbc1cd954, 626da583aa, 92faca139d, cec05c5ee9, eaf054ef06, a7a1a24658, 687122911d, 40953bd4fe, a7acc07e79
@@ -0,0 +1,125 @@
```python
"""Update GitHub connector repo_name to repositories

Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962

"""
from alembic import op
import sqlalchemy as sa
import json
import logging

# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None

logger = logging.getLogger("alembic.runtime.migration")


def upgrade() -> None:
    # Get all GitHub connectors
    conn = op.get_bind()

    # First get all GitHub connectors
    github_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'GITHUB'
            """
        )
    ).fetchall()

    # Update each connector's config
    updated_count = 0
    for connector_id, config in github_connectors:
        try:
            if not config:
                logger.warning(f"Connector {connector_id} has no config, skipping")
                continue

            # Parse the config if it's a string
            if isinstance(config, str):
                config = json.loads(config)

            if "repo_name" not in config:
                continue

            # Create new config with repositories instead of repo_name
            new_config = dict(config)
            repo_name_value = new_config.pop("repo_name")
            new_config["repositories"] = repo_name_value

            # Update the connector with the new config
            conn.execute(
                sa.text(
                    """
                    UPDATE connector
                    SET connector_specific_config = :new_config
                    WHERE id = :connector_id
                    """
                ),
                {"connector_id": connector_id, "new_config": json.dumps(new_config)},
            )
            updated_count += 1
        except Exception as e:
            logger.error(f"Error updating connector {connector_id}: {str(e)}")


def downgrade() -> None:
    # Get all GitHub connectors
    conn = op.get_bind()

    logger.debug(
        "Starting rollback of GitHub connectors from repositories to repo_name"
    )

    github_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'GITHUB'
            """
        )
    ).fetchall()

    logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")

    # Revert each GitHub connector to use repo_name instead of repositories
    reverted_count = 0
    for connector_id, config in github_connectors:
        try:
            if not config:
                continue

            # Parse the config if it's a string
            if isinstance(config, str):
                config = json.loads(config)

            if "repositories" not in config:
                continue

            # Create new config with repo_name instead of repositories
            new_config = dict(config)
            repositories_value = new_config.pop("repositories")
            new_config["repo_name"] = repositories_value

            # Update the connector with the new config
            conn.execute(
                sa.text(
                    """
                    UPDATE connector
                    SET connector_specific_config = :new_config
                    WHERE id = :connector_id
                    """
                ),
                {"new_config": json.dumps(new_config), "connector_id": connector_id},
            )
            reverted_count += 1
        except Exception as e:
            logger.error(f"Error reverting connector {connector_id}: {str(e)}")
```
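At its core the migration is a single key rename inside each connector's JSON config; everything else is fetch/parse/write plumbing. A standalone sketch of that rename on a hypothetical config (sample values, not taken from a real connector):

```python
import json

# Hypothetical pre-migration connector config.
old_config = {"repo_owner": "onyx-dot-app", "repo_name": "onyx", "include_prs": True}

# The same transformation upgrade() applies: move the value to the new key.
new_config = dict(old_config)
new_config["repositories"] = new_config.pop("repo_name")

assert "repo_name" not in new_config
print(json.dumps(new_config, indent=2))
```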
@@ -134,7 +134,9 @@ def fetch_chat_sessions_eagerly_by_time(
```python
    limit: int | None = 500,
    initial_time: datetime | None = None,
) -> list[ChatSession]:
    time_order: UnaryExpression = desc(ChatSession.time_created)
    """Sorted by oldest to newest, then by message id"""

    asc_time_order: UnaryExpression = asc(ChatSession.time_created)
    message_order: UnaryExpression = asc(ChatMessage.id)

    filters: list[ColumnElement | BinaryExpression] = [
```
@@ -147,8 +149,7 @@ def fetch_chat_sessions_eagerly_by_time(
```python
    subquery = (
        db_session.query(ChatSession.id, ChatSession.time_created)
        .filter(*filters)
        .order_by(ChatSession.id, time_order)
        .distinct(ChatSession.id)
        .order_by(asc_time_order)
        .limit(limit)
        .subquery()
    )
```
@@ -164,7 +165,7 @@ def fetch_chat_sessions_eagerly_by_time(
```python
                ChatMessage.chat_message_feedbacks
            ),
        )
        .order_by(time_order, message_order)
        .order_by(asc_time_order, message_order)
    )

    chat_sessions = query.all()
```
@@ -16,13 +16,18 @@ from onyx.db.models import UsageReport
```python
from onyx.file_store.file_store import get_default_file_store


# Gets skeletons of all message
# Gets skeletons of all messages in the given range
def get_empty_chat_messages_entries__paginated(
    db_session: Session,
    period: tuple[datetime, datetime],
    limit: int | None = 500,
    initial_time: datetime | None = None,
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
    """Returns a tuple where:
    first element is the most recent timestamp out of the sessions iterated
      - this timestamp can be used to paginate forward in time
    second element is a list of messages belonging to all the sessions iterated
    """
    chat_sessions = fetch_chat_sessions_eagerly_by_time(
        start=period[0],
        end=period[1],
```
@@ -52,18 +57,17 @@ def get_empty_chat_messages_entries__paginated(
```python
    if len(chat_sessions) == 0:
        return None, []

    return chat_sessions[0].time_created, message_skeletons
    return chat_sessions[-1].time_created, message_skeletons


def get_all_empty_chat_message_entries(
    db_session: Session,
    period: tuple[datetime, datetime],
) -> Generator[list[ChatMessageSkeleton], None, None]:
    """period is the range of time over which to fetch messages."""
    initial_time: Optional[datetime] = period[0]
    ind = 0
    while True:
        ind += 1

        # iterate from oldest to newest
        time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
            db_session,
            period,
```
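After this change, each page reports the newest timestamp it saw (chat_sessions[-1] under ascending order), and the caller feeds that back as the cursor for the next page. A self-contained sketch of the same forward-in-time cursor pattern, with a hypothetical fetch_page standing in for the paginated query:

```python
from datetime import datetime, timedelta
from typing import Optional

ROWS = [datetime(2025, 1, 1) + timedelta(days=i) for i in range(8)]

def fetch_page(cursor: Optional[datetime], limit: int = 3) -> list[datetime]:
    # Hypothetical stand-in for one paginated DB query: rows sorted
    # oldest-to-newest, strictly after the cursor.
    after = [r for r in ROWS if cursor is None or r > cursor]
    return after[:limit]

cursor: Optional[datetime] = None
while True:
    page = fetch_page(cursor)
    if not page:
        break
    cursor = page[-1]  # newest row seen becomes the next page's cursor
    print([d.date().isoformat() for d in page])
```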
@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
```python
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
    router as chat_router,
)
```
@@ -26,7 +26,7 @@ from ee.onyx.server.query_history.api import router as query_history_router
```python
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.saml import router as saml_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.router import router as tenants_router
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
    router as token_rate_limit_settings_router,
)
```
@@ -128,7 +128,7 @@ def get_application() -> FastAPI:
```python
    include_router_with_global_prefix_prepended(application, query_router)
    include_router_with_global_prefix_prepended(application, chat_router)
    include_router_with_global_prefix_prepended(application, standard_answer_router)
    include_router_with_global_prefix_prepended(application, oauth_router)
    include_router_with_global_prefix_prepended(application, ee_oauth_router)

    # Enterprise-only global settings
    include_router_with_global_prefix_prepended(
```
@@ -80,6 +80,7 @@ class ConfluenceCloudOAuth:
```python
        "search:confluence%20"
        # granular scope
        "read:attachment:confluence%20"  # possibly unneeded unless calling v2 attachments api
        "read:content-details:confluence%20"  # for permission sync
        "offline_access"
    )
```
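The scope string above is concatenated with literal %20 separators because it gets embedded directly in the OAuth authorization URL. A sketch of an equivalent construction from a list, assuming only that the scopes are space-delimited and URL-encoded (this is not the repo's code):

```python
from urllib.parse import quote

scopes = [
    "search:confluence",
    "read:attachment:confluence",
    "read:content-details:confluence",  # for permission sync
    "offline_access",
]

# Join with spaces, then percent-encode; keep ":" literal so the result
# matches the hand-built "search:confluence%20..." string.
scope_param = quote(" ".join(scopes), safe=":")
print(scope_param)
```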
@@ -1,45 +0,0 @@
```python
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response

from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger

logger = setup_logger()

router = APIRouter(prefix="/tenants")


@router.post("/impersonate")
async def impersonate_user(
    impersonate_request: ImpersonateRequest,
    _: User = Depends(current_cloud_superuser),
) -> Response:
    """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
    tenant_id = get_tenant_id_for_email(impersonate_request.email)

    with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
        user_to_impersonate = get_user_by_email(
            impersonate_request.email, tenant_session
        )
        if user_to_impersonate is None:
            raise HTTPException(status_code=404, detail="User not found")
        token = await get_redis_strategy().write_token(user_to_impersonate)

    response = await auth_backend.transport.get_login_response(token)
    response.set_cookie(
        key="fastapiusersauth",
        value=token,
        httponly=True,
        secure=True,
        samesite="lax",
    )
    return response
```
@@ -1,98 +0,0 @@
```python
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError

from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
    get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

router = APIRouter(prefix="/tenants")


@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
    _: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
    tenant_id = get_current_tenant_id()

    if tenant_id is None:
        raise HTTPException(status_code=404, detail="Tenant not found")

    with get_session_with_shared_schema() as db_session:
        current_path = get_anonymous_user_path(tenant_id, db_session)

    return AnonymousUserPath(anonymous_user_path=current_path)


@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
    anonymous_user_path: str,
    _: User | None = Depends(current_admin_user),
) -> None:
    tenant_id = get_current_tenant_id()
    try:
        validate_anonymous_user_path(anonymous_user_path)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    with get_session_with_shared_schema() as db_session:
        try:
            modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
        except IntegrityError:
            raise HTTPException(
                status_code=409,
                detail="The anonymous user path is already in use. Please choose a different path.",
            )
        except Exception as e:
            logger.exception(f"Failed to modify anonymous user path: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail="An unexpected error occurred while modifying the anonymous user path",
            )


@router.post("/anonymous-user")
async def login_as_anonymous_user(
    anonymous_user_path: str,
    _: User | None = Depends(optional_user),
) -> Response:
    with get_session_with_shared_schema() as db_session:
        tenant_id = get_tenant_id_for_anonymous_user_path(
            anonymous_user_path, db_session
        )
        if not tenant_id:
            raise HTTPException(status_code=404, detail="Tenant not found")

    if not anonymous_user_enabled(tenant_id=tenant_id):
        raise HTTPException(status_code=403, detail="Anonymous user is not enabled")

    token = generate_anonymous_user_jwt_token(tenant_id)

    response = Response()
    response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
    response.set_cookie(
        key=ANONYMOUS_USER_COOKIE_NAME,
        value=token,
        httponly=True,
        secure=True,
        samesite="strict",
    )
    return response
```
@@ -1,143 +0,0 @@
```python
import asyncio
import logging

from sqlalchemy.orm import Session

from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider

logger = logging.getLogger(__name__)


async def complete_tenant_setup(tenant_id: str, email: str) -> None:
    """
    Complete the tenant setup process asynchronously after the essential migrations
    have been applied. This includes:
    1. Running the remaining Alembic migrations
    2. Setting up Onyx
    3. Creating milestone records
    """
    logger.info(f"Starting asynchronous tenant setup for tenant {tenant_id}")
    token = None

    try:
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

        # Run the remaining Alembic migrations
        await asyncio.to_thread(run_alembic_migrations, tenant_id)

        # Configure default API keys
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            configure_default_api_keys(db_session)

        # Setup Onyx
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            current_search_settings = (
                db_session.query(SearchSettings)
                .filter_by(status=IndexModelStatus.FUTURE)
                .first()
            )
            cohere_enabled = (
                current_search_settings is not None
                and current_search_settings.provider_type == EmbeddingProvider.COHERE
            )
            setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)

        # Create milestone record
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            create_milestone_and_report(
                user=None,
                distinct_id=tenant_id,
                event_type=MilestoneRecordType.TENANT_CREATED,
                properties={
                    "email": email,
                },
                db_session=db_session,
            )

        logger.info(f"Asynchronous tenant setup completed for tenant {tenant_id}")

    except Exception as e:
        logger.exception(
            f"Failed to complete asynchronous tenant setup for tenant {tenant_id}: {e}"
        )
    finally:
        if token is not None:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


def configure_default_api_keys(db_session: Session) -> None:
    if ANTHROPIC_DEFAULT_API_KEY:
        anthropic_provider = LLMProviderUpsertRequest(
            name="Anthropic",
            provider=ANTHROPIC_PROVIDER_NAME,
            api_key=ANTHROPIC_DEFAULT_API_KEY,
            default_model_name="claude-3-7-sonnet-20250219",
            fast_default_model_name="claude-3-5-sonnet-20241022",
            model_names=ANTHROPIC_MODEL_NAMES,
            display_model_names=["claude-3-5-sonnet-20241022"],
        )
        try:
            full_provider = upsert_llm_provider(anthropic_provider, db_session)
            update_default_provider(full_provider.id, db_session)
        except Exception as e:
            logger.error(f"Failed to configure Anthropic provider: {e}")
    else:
        logger.error(
            "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
        )

    if OPENAI_DEFAULT_API_KEY:
        open_provider = LLMProviderUpsertRequest(
            name="OpenAI",
            provider=OPENAI_PROVIDER_NAME,
            api_key=OPENAI_DEFAULT_API_KEY,
            default_model_name="gpt-4o",
            fast_default_model_name="gpt-4o-mini",
            model_names=OPEN_AI_MODEL_NAMES,
            display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
        )
        try:
            full_provider = upsert_llm_provider(open_provider, db_session)
            update_default_provider(full_provider.id, db_session)
        except Exception as e:
            logger.error(f"Failed to configure OpenAI provider: {e}")
    else:
        logger.error(
            "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
        )

    if COHERE_DEFAULT_API_KEY:
        cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
            provider_type=EmbeddingProvider.COHERE,
            api_key=COHERE_DEFAULT_API_KEY,
        )

        try:
            logger.info("Attempting to upsert Cohere cloud embedding provider")
            upsert_cloud_embedding_provider(cloud_embedding_provider, db_session)
        except Exception as e:
            logger.error(f"Failed to configure Cohere provider: {e}")
    else:
        logger.error(
            "COHERE_DEFAULT_API_KEY not set, skipping Cohere provider configuration"
        )
```
@@ -1,96 +0,0 @@
```python
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException

from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id

stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()

router = APIRouter(prefix="/tenants")


@router.post("/product-gating")
def gate_product(
    product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
    """
    Gating the product means that the product is not available to the tenant.
    They will be directed to the billing page.
    We gate the product when their subscription has ended.
    """
    try:
        store_product_gating(
            product_gating_request.tenant_id, product_gating_request.application_status
        )
        return ProductGatingResponse(updated=True, error=None)

    except Exception as e:
        logger.exception("Failed to gate product")
        return ProductGatingResponse(updated=False, error=str(e))


@router.get("/billing-information")
async def billing_information(
    _: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
    logger.info("Fetching billing information")
    tenant_id = get_current_tenant_id()
    return fetch_billing_information(tenant_id)


@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
    _: User = Depends(current_admin_user),
) -> dict:
    tenant_id = get_current_tenant_id()

    try:
        stripe_info = fetch_tenant_stripe_information(tenant_id)
        stripe_customer_id = stripe_info.get("stripe_customer_id")
        if not stripe_customer_id:
            raise HTTPException(status_code=400, detail="Stripe customer ID not found")
        logger.info(stripe_customer_id)

        portal_session = stripe.billing_portal.Session.create(
            customer=stripe_customer_id,
            return_url=f"{WEB_DOMAIN}/admin/billing",
        )
        logger.info(portal_session)
        return {"url": portal_session.url}
    except Exception as e:
        logger.exception("Failed to create customer portal session")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/create-subscription-session")
async def create_subscription_session(
    _: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
    try:
        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
        if not tenant_id:
            raise HTTPException(status_code=400, detail="Tenant ID not found")
        session_id = fetch_stripe_checkout_session(tenant_id)
        return SubscriptionSessionResponse(sessionId=session_id)

    except Exception as e:
        logger.exception("Failed to create resubscription session")
        raise HTTPException(status_code=500, detail=str(e))
```
@@ -67,19 +67,3 @@ class ProductGatingResponse(BaseModel):
```python

class SubscriptionSessionResponse(BaseModel):
    sessionId: str


class TenantByDomainResponse(BaseModel):
    tenant_id: str
    status: str
    is_complete: bool


class ApproveUserRequest(BaseModel):
    email: str
    tenant_id: str


class RequestInviteRequest(BaseModel):
    email: str
    tenant_id: str
```
@@ -48,4 +48,5 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
```python

def get_gated_tenants() -> set[str]:
    redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
    gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
    return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
```
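The two-step decode is needed because redis-py returns smembers results as bytes unless the client was built with decode_responses=True. A minimal sketch of the behavior, assuming a Redis instance on the default local port:

```python
import redis

r = redis.Redis()  # decode_responses defaults to False, so members come back as bytes
r.sadd("gated_tenants", "tenant_a", "tenant_b")

raw = r.smembers("gated_tenants")           # {b"tenant_a", b"tenant_b"}
decoded = {m.decode("utf-8") for m in raw}  # {"tenant_a", "tenant_b"}
print(decoded)
```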
@@ -6,28 +6,47 @@ import aiohttp  # Async HTTP client
```python
import httpx
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session

from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.async_setup import complete_tenant_setup
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import drop_schema
from ee.onyx.server.tenants.schema_management import run_essential_alembic_migrations
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from ee.onyx.server.tenants.user_mapping import add_users_to_tenant
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider


logger = logging.getLogger(__name__)
```
@@ -36,7 +55,11 @@ logger = logging.getLogger(__name__)
```python
async def get_or_provision_tenant(
    email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
    """Get existing tenant ID for an email or create a new tenant if none exists."""
    """
    Get existing tenant ID for an email or create a new tenant if none exists.
    This function should only be called after we have verified we want this user's tenant to exist.
    It returns the tenant ID associated with the email, creating a new tenant if necessary.
    """
    if not MULTI_TENANT:
        return POSTGRES_DEFAULT_SCHEMA
```
@@ -96,19 +119,35 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
```python
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

        # Run only the essential Alembic migrations needed for auth
        await asyncio.to_thread(run_essential_alembic_migrations, tenant_id)
        # Await the Alembic migrations
        await asyncio.to_thread(run_alembic_migrations, tenant_id)

        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            configure_default_api_keys(db_session)

            current_search_settings = (
                db_session.query(SearchSettings)
                .filter_by(status=IndexModelStatus.FUTURE)
                .first()
            )
            cohere_enabled = (
                current_search_settings is not None
                and current_search_settings.provider_type == EmbeddingProvider.COHERE
            )
            setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)

        # Add user to tenant immediately so they can log in
        add_users_to_tenant([email], tenant_id)

        # Start the rest of the setup process asynchronously
        asyncio.create_task(complete_tenant_setup(tenant_id, email))

        logger.info(f"Essential tenant provisioning completed for tenant {tenant_id}")
        logger.info(
            f"Remaining setup will continue asynchronously for tenant {tenant_id}"
        )
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            create_milestone_and_report(
                user=None,
                distinct_id=tenant_id,
                event_type=MilestoneRecordType.TENANT_CREATED,
                properties={
                    "email": email,
                },
                db_session=db_session,
            )

    except Exception as e:
        logger.exception(f"Failed to create tenant {tenant_id}")
```
@@ -164,43 +203,136 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
```python
        logger.error(f"Failed to rollback tenant provisioning: {e}")


def configure_default_api_keys(db_session: Session) -> None:
    if ANTHROPIC_DEFAULT_API_KEY:
        anthropic_provider = LLMProviderUpsertRequest(
            name="Anthropic",
            provider=ANTHROPIC_PROVIDER_NAME,
            api_key=ANTHROPIC_DEFAULT_API_KEY,
            default_model_name="claude-3-7-sonnet-20250219",
            fast_default_model_name="claude-3-5-sonnet-20241022",
            model_names=ANTHROPIC_MODEL_NAMES,
            display_model_names=["claude-3-5-sonnet-20241022"],
        )
        try:
            full_provider = upsert_llm_provider(anthropic_provider, db_session)
            update_default_provider(full_provider.id, db_session)
        except Exception as e:
            logger.error(f"Failed to configure Anthropic provider: {e}")
    else:
        logger.error(
            "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
        )

    if OPENAI_DEFAULT_API_KEY:
        open_provider = LLMProviderUpsertRequest(
            name="OpenAI",
            provider=OPENAI_PROVIDER_NAME,
            api_key=OPENAI_DEFAULT_API_KEY,
            default_model_name="gpt-4o",
            fast_default_model_name="gpt-4o-mini",
            model_names=OPEN_AI_MODEL_NAMES,
            display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
        )
        try:
            full_provider = upsert_llm_provider(open_provider, db_session)
            update_default_provider(full_provider.id, db_session)
        except Exception as e:
            logger.error(f"Failed to configure OpenAI provider: {e}")
    else:
        logger.error(
            "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
        )

    if COHERE_DEFAULT_API_KEY:
        cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
            provider_type=EmbeddingProvider.COHERE,
            api_key=COHERE_DEFAULT_API_KEY,
        )

        try:
            logger.info("Attempting to upsert Cohere cloud embedding provider")
            upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
            logger.info("Successfully upserted Cohere cloud embedding provider")

            logger.info("Updating search settings with Cohere embedding model details")
            query = (
                select(SearchSettings)
                .where(SearchSettings.status == IndexModelStatus.FUTURE)
                .order_by(SearchSettings.id.desc())
            )
            result = db_session.execute(query)
            current_search_settings = result.scalars().first()

            if current_search_settings:
                current_search_settings.model_name = (
                    "embed-english-v3.0"  # Cohere's latest model as of now
                )
                current_search_settings.model_dim = (
                    1024  # Cohere's embed-english-v3.0 dimension
                )
                current_search_settings.provider_type = EmbeddingProvider.COHERE
                current_search_settings.index_name = (
                    "danswer_chunk_cohere_embed_english_v3_0"
                )
                current_search_settings.query_prefix = ""
                current_search_settings.passage_prefix = ""
                db_session.commit()
            else:
                raise RuntimeError(
                    "No search settings specified, DB is not in a valid state"
                )
            logger.info("Fetching updated search settings to verify changes")
            updated_query = (
                select(SearchSettings)
                .where(SearchSettings.status == IndexModelStatus.PRESENT)
                .order_by(SearchSettings.id.desc())
            )
            updated_result = db_session.execute(updated_query)
            updated_result.scalars().first()

        except Exception:
            logger.exception("Failed to configure Cohere embedding provider")
    else:
        logger.info(
            "COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
        )


async def submit_to_hubspot(
    email: str, referral_source: str | None, request: Request
) -> None:
    if not HUBSPOT_TRACKING_URL:
        logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
        return

    try:
        user_agent = request.headers.get("user-agent", "")
        referer = request.headers.get("referer", "")
        ip_address = request.client.host if request.client else ""
        # HubSpot tracking cookie
        hubspot_cookie = request.cookies.get("hubspotutk")

        payload = {
            "email": email,
            "referral_source": referral_source or "",
            "user_agent": user_agent,
            "referer": referer,
            "ip_address": ip_address,
        }
        # IP address
        ip_address = request.client.host if request.client else None

        async with httpx.AsyncClient() as client:
            response = await client.post(
                HUBSPOT_TRACKING_URL,
                json=payload,
                timeout=5.0,
            )
            if response.status_code != 200:
                logger.error(
                    f"Failed to submit to HubSpot: {response.status_code} {response.text}"
                )
    except Exception as e:
        logger.error(f"Error submitting to HubSpot: {e}")
    data = {
        "fields": [
            {"name": "email", "value": email},
            {"name": "referral_source", "value": referral_source or ""},
        ],
        "context": {
            "hutk": hubspot_cookie,
            "ipAddress": ip_address,
            "pageUri": str(request.url),
            "pageName": "User Registration",
        },
    }

    async with httpx.AsyncClient() as client:
        response = await client.post(HUBSPOT_TRACKING_URL, json=data)

        if response.status_code != 200:
            logger.error(f"Failed to submit to HubSpot: {response.text}")


async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
    if DEV_MODE:
        return

    token = generate_data_plane_token()
    headers = {
        "Authorization": f"Bearer {token}",
```
@@ -209,14 +341,15 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
```python
    payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)

    async with aiohttp.ClientSession() as session:
        async with session.post(
        async with session.delete(
            f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
            headers=headers,
            json=payload.model_dump(),
        ) as response:
            print(response)
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"Control plane tenant deletion failed: {error_text}")
                logger.error(f"Control plane tenant creation failed: {error_text}")
                raise Exception(
                    f"Failed to delete tenant on control plane: {error_text}"
                )
```
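Throughout provisioning, the blocking Alembic upgrade is pushed onto a worker thread with asyncio.to_thread so the event loop stays responsive while migrations run. A minimal sketch of that pattern with a hypothetical blocking function:

```python
import asyncio
import time

def run_migrations(schema: str) -> str:
    # Hypothetical stand-in for a blocking Alembic upgrade.
    time.sleep(0.1)
    return f"migrated {schema}"

async def provision(schema: str) -> None:
    # Offload the blocking call to a thread; other tasks keep running meanwhile.
    result = await asyncio.to_thread(run_migrations, schema)
    print(result)

asyncio.run(provision("tenant_abc"))
```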
@@ -1,62 +0,0 @@
```python
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel

from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
    router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
    router as user_invitations_router,
)
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

# from ee.onyx.server.tenants.provisioning import get_tenant_setup_status

logger = setup_logger()

# Create a main router to include all sub-routers
router = APIRouter()

# Include all the sub-routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)


class TenantSetupStatusResponse(BaseModel):
    """Response model for tenant setup status."""

    tenant_id: str
    status: str
    is_complete: bool


# Add the setup status endpoint directly to the main router
@router.get("/tenants/setup-status", response_model=TenantSetupStatusResponse)
async def get_setup_status(
    current_user: User = Depends(current_user),
) -> TenantSetupStatusResponse:
    """
    Get the current setup status for the tenant.
    This is used by the frontend to determine if the tenant setup is complete.
    """
    tenant_id = get_current_tenant_id()
    if not tenant_id:
        raise HTTPException(status_code=404, detail="Tenant not found")

    # status = get_tenant_setup_status(tenant_id)

    return TenantSetupStatusResponse(
        tenant_id=tenant_id, status="completed", is_complete=True
    )
```
@@ -49,47 +49,6 @@ def run_alembic_migrations(schema_name: str) -> None:
```python
        raise


def run_essential_alembic_migrations(schema_name: str) -> None:
    """
    Run only the essential Alembic migrations up to the 465f78d9b7f9 revision.
    This is used for the auth flow to complete quickly, with the rest of the migrations
    and setup being deferred to run asynchronously.
    """
    logger.info(f"Starting essential Alembic migrations for schema: {schema_name}")

    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
        alembic_ini_path = os.path.join(root_dir, "alembic.ini")

        # Configure Alembic
        alembic_cfg = Config(alembic_ini_path)
        alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
        alembic_cfg.set_main_option(
            "script_location", os.path.join(root_dir, "alembic")
        )

        # Ensure that logging isn't broken
        alembic_cfg.attributes["configure_logger"] = False

        # Mimic command-line options by adding 'cmd_opts' to the config
        alembic_cfg.cmd_opts = SimpleNamespace()  # type: ignore
        alembic_cfg.cmd_opts.x = [f"schema={schema_name}"]  # type: ignore

        # Run migrations programmatically up to the specified revision
        command.upgrade(alembic_cfg, "465f78d9b7f9")

        logger.info(
            f"Essential Alembic migrations completed successfully for schema: {schema_name}"
        )

    except Exception as e:
        logger.exception(
            f"Essential Alembic migration failed for schema {schema_name}: {str(e)}"
        )
        raise


def create_schema_if_not_exists(tenant_id: str) -> bool:
    with Session(get_sqlalchemy_engine()) as db_session:
        with db_session.begin():
```
@@ -1,67 +0,0 @@
```python
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

router = APIRouter(prefix="/tenants")


@router.post("/leave-team")
async def leave_organization(
    user_email: UserByEmail,
    current_user: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    tenant_id = get_current_tenant_id()

    if current_user is None or current_user.email != user_email.user_email:
        raise HTTPException(
            status_code=403, detail="You can only leave the organization as yourself"
        )

    user_to_delete = get_user_by_email(user_email.user_email, db_session)
    if user_to_delete is None:
        raise HTTPException(status_code=404, detail="User not found")

    num_admin_users = await get_user_count(only_admin_users=True)

    should_delete_tenant = num_admin_users == 1

    if should_delete_tenant:
        logger.info(
            "Last admin user is leaving the organization. Deleting tenant from control plane."
        )
        try:
            await delete_user_from_control_plane(tenant_id, user_to_delete.email)
            logger.debug("User deleted from control plane")
        except Exception as e:
            logger.exception(
                f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
            )
            raise HTTPException(
                status_code=500,
                detail=f"Failed to remove user from control plane: {str(e)}",
            )

    db_session.expunge(user_to_delete)
    delete_user_from_db(user_to_delete, db_session)

    if should_delete_tenant:
        remove_all_users_from_tenant(tenant_id)
    else:
        remove_users_from_tenant([user_to_delete.email], tenant_id)
```
@@ -1,62 +0,0 @@
```python
from fastapi import APIRouter
from fastapi import Depends

from ee.onyx.server.tenants.models import TenantByDomainResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

# from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane

logger = setup_logger()

router = APIRouter(prefix="/tenants")

FORBIDDEN_COMMON_EMAIL_DOMAINS = [
    "gmail.com",
    "yahoo.com",
    "hotmail.com",
    "outlook.com",
    "icloud.com",
    "msn.com",
    "live.com",
    "msn.com",
    "hotmail.com",
    "hotmail.co.uk",
    "hotmail.fr",
    "hotmail.de",
    "hotmail.it",
    "hotmail.es",
    "hotmail.nl",
    "hotmail.pl",
    "hotmail.pt",
    "hotmail.ro",
    "hotmail.ru",
    "hotmail.sa",
    "hotmail.se",
    "hotmail.tr",
    "hotmail.tw",
    "hotmail.ua",
    "hotmail.us",
    "hotmail.vn",
    "hotmail.za",
    "hotmail.zw",
]


@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
    user: User | None = Depends(current_admin_user),
) -> TenantByDomainResponse | None:
    if not user:
        return None
    domain = user.email.split("@")[1]
    if domain in FORBIDDEN_COMMON_EMAIL_DOMAINS:
        return None
    tenant_id = get_current_tenant_id()
    return TenantByDomainResponse(
        tenant_id=tenant_id, status="completed", is_complete=True
    )

    # return get_tenant_by_domain_from_control_plane(domain, tenant_id)
```
@@ -1,91 +0,0 @@
```python
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException

from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

router = APIRouter(prefix="/tenants")


@router.post("/request-invite")
async def request_invite(
    invite_request: RequestInviteRequest,
    user: User | None = Depends(current_admin_user),
) -> None:
    if user is None:
        raise HTTPException(status_code=401, detail="User not authenticated")
    try:
        invite_self_to_tenant(user.email, invite_request.tenant_id)
    except Exception as e:
        logger.exception(
            f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
        )
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/users/pending")
def list_pending_users(
    _: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
    pending_emails = get_pending_users()

    return [PendingUserSnapshot(email=email) for email in pending_emails]


@router.post("/users/approve-invite")
async def approve_user(
    approve_user_request: ApproveUserRequest,
    _: User | None = Depends(current_admin_user),
) -> None:
    tenant_id = get_current_tenant_id()
    approve_user_invite(approve_user_request.email, tenant_id)


@router.post("/users/accept-invite")
async def accept_invite(
    invite_request: RequestInviteRequest,
    user: User | None = Depends(current_user),
) -> None:
    """
    Accept an invitation to join a tenant.
    """
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        accept_user_invite(user.email, invite_request.tenant_id)
    except Exception as e:
        logger.exception(f"Failed to accept invite: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to accept invitation")


@router.post("/users/deny-invite")
async def deny_invite(
    invite_request: RequestInviteRequest,
    user: User | None = Depends(current_user),
) -> None:
    """
    Deny an invitation to join a tenant.
    """
    if not user:
        raise HTTPException(status_code=401, detail="Not authenticated")

    try:
        deny_user_invite(user.email, invite_request.tenant_id)
    except Exception as e:
        logger.exception(f"Failed to deny invite: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to deny invitation")
```
@@ -1,52 +0,0 @@
```python
from typing import Optional

from fastapi import Depends
from fastapi import Request
from fastapi_users import BaseUserManager
from fastapi_users import UUIDIDMixin
from fastapi_users.db import SQLAlchemyUserDatabase

from onyx.auth.essential_user import EssentialUser
from onyx.auth.essential_user import get_essential_user_db
from onyx.configs.app_configs import USER_MANAGER_SECRET


class EssentialUserManager(UUIDIDMixin, BaseUserManager[EssentialUser, str]):
    """
    A simplified user manager that only handles essential authentication operations.
    This is used during the initial tenant setup phase to avoid errors with missing columns.
    """

    reset_password_token_secret = USER_MANAGER_SECRET
    verification_token_secret = USER_MANAGER_SECRET

    async def on_after_register(
        self, user: EssentialUser, request: Optional[Request] = None
    ) -> None:
        """
        Simplified post-registration hook.
        """

    async def on_after_forgot_password(
        self, user: EssentialUser, token: str, request: Optional[Request] = None
    ) -> None:
        """
        Simplified post-forgot-password hook.
        """

    async def on_after_request_verify(
        self, user: EssentialUser, token: str, request: Optional[Request] = None
    ) -> None:
        """
        Simplified post-verification-request hook.
        """


async def get_essential_user_manager(
    user_db: SQLAlchemyUserDatabase = Depends(get_essential_user_db),
) -> EssentialUserManager:
    """
    Get a user manager that uses the essential user model.
    This avoids errors with missing columns during the initial tenant setup.
    """
    yield EssentialUserManager(user_db)
```
@@ -1,47 +0,0 @@
```python
from collections.abc import AsyncGenerator
from typing import Optional

from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import relationship

from onyx.db.engine import get_async_session

Base: DeclarativeMeta = declarative_base()


class EssentialUser(SQLAlchemyBaseUserTableUUID, Base):
    """
    A simplified user model that only includes essential columns needed for authentication.
    This is used during the initial tenant setup phase to avoid errors with missing columns
    that would be added in later migrations.
    """

    __tablename__ = "user"

    email: str = Column(String(length=320), unique=True, index=True, nullable=False)
    hashed_password: Optional[str] = Column(String(length=1024), nullable=True)
    is_active: bool = Column(Boolean, default=True, nullable=False)
    is_superuser: bool = Column(Boolean, default=False, nullable=False)
    is_verified: bool = Column(Boolean, default=False, nullable=False)

    # Relationships are defined but not used in the essential auth flow
    oauth_accounts = relationship("OAuthAccount", lazy="joined")
    credentials = relationship("Credential", lazy="joined")


async def get_essential_user_db(
    session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
    """
    Get a user database that uses the essential user model.
    This avoids errors with missing columns during the initial tenant setup.
    """
    yield SQLAlchemyUserDatabase(session, EssentialUser)
```
@@ -587,14 +587,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
```python
    ) -> Optional[User]:
        email = credentials.username

        # Get tenant_id from mapping table
        tenant_id = await fetch_ee_implementation_or_noop(
            "onyx.server.tenants.provisioning",
            "get_or_provision_tenant",
            async_return_default_schema,
        )(
            email=email,
        )
        tenant_id: str | None = None
        try:
            tenant_id = fetch_ee_implementation_or_noop(
                "onyx.server.tenants.provisioning",
                "get_tenant_id_for_email",
                None,
            )(
                email=email,
            )
        except Exception as e:
            logger.warning(
                f"User attempted to login with invalid credentials: {str(e)}"
            )

        if not tenant_id:
            # User not found in mapping
            self.password_helper.hash(credentials.password)
```
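Note the bare self.password_helper.hash(credentials.password) on the not-found path: hashing anyway keeps login latency roughly uniform, so a caller cannot tell "unknown user" apart from "wrong password" by timing. A standalone sketch of the idea, with PBKDF2 standing in for the real hasher:

```python
import hashlib
import hmac

def slow_hash(password: str) -> str:
    # Illustrative stand-in for a real password hasher (bcrypt/argon2).
    return hashlib.pbkdf2_hmac("sha256", password.encode(), b"fixed-salt", 100_000).hex()

USERS = {"known@example.com": slow_hash("correct-horse")}

def authenticate(email: str, password: str) -> bool:
    stored = USERS.get(email)
    if stored is None:
        slow_hash(password)  # burn the same work as a real check -> uniform timing
        return False
    return hmac.compare_digest(stored, slow_hash(password))

print(authenticate("unknown@example.com", "anything"))  # False, but not faster
```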
@@ -642,14 +642,4 @@ MOCK_LLM_RESPONSE = (
```python
)


# Image processing configurations
ENABLE_IMAGE_EXTRACTION = (
    os.environ.get("ENABLE_IMAGE_EXTRACTION", "true").lower() == "true"
)
ENABLE_INDEXING_TIME_IMAGE_ANALYSIS = not (
    os.environ.get("DISABLE_INDEXING_TIME_IMAGE_ANALYSIS", "false").lower() == "true"
)
ENABLE_SEARCH_TIME_IMAGE_ANALYSIS = not (
    os.environ.get("DISABLE_SEARCH_TIME_IMAGE_ANALYSIS", "false").lower() == "true"
)
IMAGE_ANALYSIS_MAX_SIZE_MB = int(os.environ.get("IMAGE_ANALYSIS_MAX_SIZE_MB", "20"))
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
```
backend/onyx/configs/llm_configs.py (new file, 38 lines)
@@ -0,0 +1,38 @@
|
||||
from onyx.configs.app_configs import DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB
|
||||
from onyx.server.settings.store import load_settings
|
||||
|
||||
|
||||
def get_image_extraction_and_analysis_enabled() -> bool:
|
||||
"""Get image extraction and analysis enabled setting from workspace settings or fallback to False"""
|
||||
try:
|
||||
settings = load_settings()
|
||||
if settings.image_extraction_and_analysis_enabled is not None:
|
||||
return settings.image_extraction_and_analysis_enabled
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_search_time_image_analysis_enabled() -> bool:
|
||||
"""Get search time image analysis enabled setting from workspace settings or fallback to False"""
|
||||
try:
|
||||
settings = load_settings()
|
||||
if settings.search_time_image_analysis_enabled is not None:
|
||||
return settings.search_time_image_analysis_enabled
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_image_analysis_max_size_mb() -> int:
|
||||
"""Get image analysis max size MB setting from workspace settings or fallback to environment variable"""
|
||||
try:
|
||||
settings = load_settings()
|
||||
if settings.image_analysis_max_size_mb is not None:
|
||||
return settings.image_analysis_max_size_mb
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB
|
||||
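Because each getter re-reads workspace settings on every call, admin changes take effect without a restart, at the cost of one settings load per check. A minimal call-site sketch, assuming the module above (`should_analyze_image` is a hypothetical helper):

    from onyx.configs.llm_configs import get_image_analysis_max_size_mb

    def should_analyze_image(size_bytes: int) -> bool:
        # Re-evaluated per call, so a changed workspace setting applies immediately.
        return size_bytes <= get_image_analysis_max_size_mb() * 1024 * 1024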
@@ -240,7 +240,7 @@ class ConfluenceConnector(
         # Extract basic page information
         page_id = page["id"]
         page_title = page["title"]
-        page_url = f"{self.wiki_base}/wiki{page['_links']['webui']}"
+        page_url = f"{self.wiki_base}{page['_links']['webui']}"

         # Get the page content
         page_content = extract_text_from_confluence_html(
@@ -144,6 +144,12 @@ class OnyxConfluence:
             self.static_credentials = credential_json
             return credential_json, False

+        if not OAUTH_CONFLUENCE_CLOUD_CLIENT_ID:
+            raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID must be set!")
+
+        if not OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET:
+            raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET must be set!")
+
         # check if we should refresh tokens. we're deciding to refresh halfway
         # to expiration
         now = datetime.now(timezone.utc)
@@ -124,14 +124,14 @@ class GithubConnector(LoadConnector, PollConnector):
     def __init__(
         self,
         repo_owner: str,
-        repo_name: str | None = None,
+        repositories: str | None = None,
         batch_size: int = INDEX_BATCH_SIZE,
         state_filter: str = "all",
        include_prs: bool = True,
        include_issues: bool = False,
    ) -> None:
        self.repo_owner = repo_owner
-        self.repo_name = repo_name
+        self.repositories = repositories
        self.batch_size = batch_size
        self.state_filter = state_filter
        self.include_prs = include_prs
@@ -157,11 +157,42 @@ class GithubConnector(LoadConnector, PollConnector):
             )

         try:
-            return github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
+            return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
         except RateLimitExceededException:
             _sleep_after_rate_limit_exception(github_client)
             return self._get_github_repo(github_client, attempt_num + 1)

+    def _get_github_repos(
+        self, github_client: Github, attempt_num: int = 0
+    ) -> list[Repository.Repository]:
+        """Get specific repositories based on the comma-separated repositories string."""
+        if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
+            raise RuntimeError(
+                "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
+            )
+
+        try:
+            repos = []
+            # Split the repositories string by comma and strip whitespace
+            repo_names = [
+                name.strip() for name in (cast(str, self.repositories)).split(",")
+            ]
+
+            for repo_name in repo_names:
+                if repo_name:  # Skip empty strings
+                    try:
+                        repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
+                        repos.append(repo)
+                    except GithubException as e:
+                        logger.warning(
+                            f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
+                        )
+
+            return repos
+        except RateLimitExceededException:
+            _sleep_after_rate_limit_exception(github_client)
+            return self._get_github_repos(github_client, attempt_num + 1)
+
     def _get_all_repos(
         self, github_client: Github, attempt_num: int = 0
     ) -> list[Repository.Repository]:
@@ -189,11 +220,17 @@ class GithubConnector(LoadConnector, PollConnector):
         if self.github_client is None:
             raise ConnectorMissingCredentialError("GitHub")

-        repos = (
-            [self._get_github_repo(self.github_client)]
-            if self.repo_name
-            else self._get_all_repos(self.github_client)
-        )
+        repos = []
+        if self.repositories:
+            if "," in self.repositories:
+                # Multiple repositories specified
+                repos = self._get_github_repos(self.github_client)
+            else:
+                # Single repository (backward compatibility)
+                repos = [self._get_github_repo(self.github_client)]
+        else:
+            # All repositories
+            repos = self._get_all_repos(self.github_client)

         for repo in repos:
             if self.include_prs:
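The branching above gives three fetch modes: a comma in `repositories` means an explicit list, a bare name means a single repo (the old `repo_name` behavior), and an empty value means everything the owner has. A self-contained sketch of just that rule (function name hypothetical):

    def select_fetch_mode(repositories: str | None) -> str:
        # Mirrors the connector's dispatch: list vs. single repo vs. all repos.
        if repositories:
            return "multiple" if "," in repositories else "single"
        return "all"

    assert select_fetch_mode("repo1,repo2") == "multiple"
    assert select_fetch_mode("onyx") == "single"
    assert select_fetch_mode(None) == "all"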
@@ -268,11 +305,48 @@ class GithubConnector(LoadConnector, PollConnector):
             )

         try:
-            if self.repo_name:
-                test_repo = self.github_client.get_repo(
-                    f"{self.repo_owner}/{self.repo_name}"
-                )
-                test_repo.get_contents("")
+            if self.repositories:
+                if "," in self.repositories:
+                    # Multiple repositories specified
+                    repo_names = [name.strip() for name in self.repositories.split(",")]
+                    if not repo_names:
+                        raise ConnectorValidationError(
+                            "Invalid connector settings: No valid repository names provided."
+                        )
+
+                    # Validate at least one repository exists and is accessible
+                    valid_repos = False
+                    validation_errors = []
+
+                    for repo_name in repo_names:
+                        if not repo_name:
+                            continue
+
+                        try:
+                            test_repo = self.github_client.get_repo(
+                                f"{self.repo_owner}/{repo_name}"
+                            )
+                            test_repo.get_contents("")
+                            valid_repos = True
+                            # If at least one repo is valid, we can proceed
+                            break
+                        except GithubException as e:
+                            validation_errors.append(
+                                f"Repository '{repo_name}': {e.data.get('message', str(e))}"
+                            )
+
+                    if not valid_repos:
+                        error_msg = (
+                            "None of the specified repositories could be accessed: "
+                        )
+                        error_msg += ", ".join(validation_errors)
+                        raise ConnectorValidationError(error_msg)
+                else:
+                    # Single repository (backward compatibility)
+                    test_repo = self.github_client.get_repo(
+                        f"{self.repo_owner}/{self.repositories}"
+                    )
+                    test_repo.get_contents("")
             else:
                 # Try to get organization first
                 try:

@@ -298,10 +372,15 @@ class GithubConnector(LoadConnector, PollConnector):
                 "Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
             )
         elif e.status == 404:
-            if self.repo_name:
-                raise ConnectorValidationError(
-                    f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
-                )
+            if self.repositories:
+                if "," in self.repositories:
+                    raise ConnectorValidationError(
+                        f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
+                    )
+                else:
+                    raise ConnectorValidationError(
+                        f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
+                    )
             else:
                 raise ConnectorValidationError(
                     f"GitHub user or organization not found: {self.repo_owner}"

@@ -310,6 +389,7 @@ class GithubConnector(LoadConnector, PollConnector):
             raise ConnectorValidationError(
                 f"Unexpected GitHub error (status={e.status}): {e.data}"
             )

+        except Exception as exc:
             raise Exception(
                 f"Unexpected error during GitHub settings validation: {exc}"
@@ -321,7 +401,7 @@ if __name__ == "__main__":

     connector = GithubConnector(
         repo_owner=os.environ["REPO_OWNER"],
-        repo_name=os.environ["REPO_NAME"],
+        repositories=os.environ["REPOSITORIES"],
     )
     connector.load_credentials(
         {"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
@@ -1,7 +1,7 @@
 """
 Mixin for connectors that need vision capabilities.
 """
-from onyx.configs.app_configs import ENABLE_INDEXING_TIME_IMAGE_ANALYSIS
+from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
 from onyx.llm.factory import get_default_llm_with_vision
 from onyx.llm.interfaces import LLM
 from onyx.utils.logger import setup_logger

@@ -30,7 +30,7 @@ class VisionEnabledConnector:
         Sets self.image_analysis_llm to the LLM instance or None if disabled.
         """
         self.image_analysis_llm: LLM | None = None
-        if ENABLE_INDEXING_TIME_IMAGE_ANALYSIS:
+        if get_image_extraction_and_analysis_enabled():
             try:
                 self.image_analysis_llm = get_default_llm_with_vision()
                 if self.image_analysis_llm is None:
@@ -10,8 +10,8 @@ from langchain_core.messages import SystemMessage

 from onyx.chat.models import SectionRelevancePiece
 from onyx.configs.app_configs import BLURB_SIZE
-from onyx.configs.app_configs import ENABLE_SEARCH_TIME_IMAGE_ANALYSIS
 from onyx.configs.constants import RETURN_SEPARATOR
+from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
 from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
 from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MIN
 from onyx.context.search.enums import LLMEvaluationType

@@ -413,7 +413,7 @@ def search_postprocessing(
     # NOTE: if we don't rerank, we can return the chunks immediately
     # since we know this is the final order.
     # This way the user experience isn't delayed by the LLM step
-    if ENABLE_SEARCH_TIME_IMAGE_ANALYSIS:
+    if get_search_time_image_analysis_enabled():
         update_image_sections_with_query(
             retrieved_sections, search_query.query, llm
         )

@@ -456,7 +456,7 @@ def search_postprocessing(
     _log_top_section_links(search_query.search_type.value, reranked_sections)

     # Add the image processing step here
-    if ENABLE_SEARCH_TIME_IMAGE_ANALYSIS:
+    if get_search_time_image_analysis_enabled():
         update_image_sections_with_query(
             reranked_sections, search_query.query, llm
         )
53  backend/onyx/db/seeding/chat_history_seeding.py  Normal file
@@ -0,0 +1,53 @@
import random
from datetime import datetime
from datetime import timedelta

from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatSession


def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
    """Utility function to seed chat history for testing.

    num_sessions: the number of sessions to seed
    num_messages: the number of messages to seed per session
    days: the number of days looking backwards from the current time over which to
        randomize the timestamps
    """
    with get_session_with_current_tenant() as db_session:
        for y in range(0, num_sessions):
            create_chat_session(db_session, f"pytest_session_{y}", None, None)

        # randomize all session times
        rows = db_session.query(ChatSession).all()
        for row in rows:
            row.time_created = datetime.utcnow() - timedelta(
                days=random.randint(0, days)
            )
            row.time_updated = row.time_created + timedelta(
                minutes=random.randint(0, 10)
            )

            root_message = get_or_create_root_message(row.id, db_session)

            for x in range(0, num_messages):
                chat_message = create_new_chat_message(
                    row.id,
                    root_message,
                    f"pytest_message_{x}",
                    None,
                    0,
                    MessageType.USER,
                    db_session,
                )

                chat_message.time_sent = row.time_created + timedelta(
                    minutes=random.randint(0, 10)
                )
                db_session.commit()

        db_session.commit()
@@ -464,12 +464,29 @@ def index_doc_batch(
         ),
     )

-    successful_doc_ids = {record.document_id for record in insertion_records}
-    if successful_doc_ids != set(updatable_ids):
+    all_returned_doc_ids = (
+        {record.document_id for record in insertion_records}
+        .union(
+            {
+                record.failed_document.document_id
+                for record in vector_db_write_failures
+                if record.failed_document
+            }
+        )
+        .union(
+            {
+                record.failed_document.document_id
+                for record in embedding_failures
+                if record.failed_document
+            }
+        )
+    )
+    if all_returned_doc_ids != set(updatable_ids):
         raise RuntimeError(
             f"Some documents were not successfully indexed. "
             f"Updatable IDs: {updatable_ids}, "
-            f"Successful IDs: {successful_doc_ids}"
+            f"Returned IDs: {all_returned_doc_ids}. "
             "This should never happen."
         )

     last_modified_ids = []
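The reworked check above treats a document as accounted for if it shows up as an insertion record or as either kind of failure record; only a document missing from all three sets trips the RuntimeError. A small sketch of the set arithmetic (IDs hypothetical):

    inserted = {"doc-1", "doc-2"}
    vector_db_failures = {"doc-3"}
    embedding_failures = {"doc-4"}

    all_returned = inserted | vector_db_failures | embedding_failures
    updatable = {"doc-1", "doc-2", "doc-3", "doc-4"}
    # A missing ID would mean a document silently fell through the pipeline.
    assert all_returned == updatable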
@@ -51,6 +51,7 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
 from onyx.server.documents.connector import router as connector_router
 from onyx.server.documents.credential import router as credential_router
 from onyx.server.documents.document import router as document_router
+from onyx.server.documents.standard_oauth import router as standard_oauth_router
 from onyx.server.features.document_set.api import router as document_set_router
 from onyx.server.features.folder.api import router as folder_router
 from onyx.server.features.input_prompt.api import (

@@ -322,6 +323,7 @@ def get_application() -> FastAPI:
     )
     include_router_with_global_prefix_prepended(application, long_term_logs_router)
     include_router_with_global_prefix_prepended(application, api_key_router)
+    include_router_with_global_prefix_prepended(application, standard_oauth_router)

     if AUTH_TYPE == AuthType.DISABLED:
         # Server logs this during auth setup verification step
@@ -53,6 +53,11 @@ class Settings(BaseModel):
     auto_scroll: bool | None = False
     query_history_type: QueryHistoryType | None = None

+    # Image processing settings
+    image_extraction_and_analysis_enabled: bool | None = False
+    search_time_image_analysis_enabled: bool | None = False
+    image_analysis_max_size_mb: int | None = 20
+

 class UserSettings(Settings):
     notifications: list[Notification]

@@ -47,6 +47,7 @@ def load_settings() -> Settings:

     settings.anonymous_user_enabled = anonymous_user_enabled
     settings.query_history_type = ONYX_QUERY_HISTORY_TYPE

     return settings
45  backend/scripts/chat_history_seeding.py  Normal file
@@ -0,0 +1,45 @@
import argparse
import logging
from logging import getLogger

from onyx.db.seeding.chat_history_seeding import seed_chat_history

# Configure the logger
logging.basicConfig(
    level=logging.INFO,  # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # Log format
    handlers=[logging.StreamHandler()],  # Output logs to console
)

logger = getLogger(__name__)


def go_main(num_sessions: int, num_messages: int, num_days: int) -> None:
    seed_chat_history(num_sessions, num_messages, num_days)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Seed chat history")
    parser.add_argument(
        "--sessions",
        type=int,
        default=2048,
        help="Number of chat sessions to seed",
    )

    parser.add_argument(
        "--messages",
        type=int,
        default=4,
        help="Number of chat messages to seed per session",
    )

    parser.add_argument(
        "--days",
        type=int,
        default=90,
        help="Number of days back from now over which to randomize the seeded timestamps",
    )

    args = parser.parse_args()
    go_main(args.sessions, args.messages, args.days)
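With the defaults above, one run seeds 2048 sessions of 4 messages spread over 90 days; a lighter run for local testing might look like `python backend/scripts/chat_history_seeding.py --sessions 64 --messages 4 --days 30` (a hypothetical invocation using the flags defined above).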
@@ -45,7 +45,7 @@ def test_confluence_connector_basic(
     with pytest.raises(StopIteration):
         next(doc_batch_generator)

-    assert len(doc_batch) == 3
+    assert len(doc_batch) == 2

     page_within_a_page_doc: Document | None = None
     page_doc: Document | None = None
@@ -0,0 +1,46 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone

from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.seeding.chat_history_seeding import seed_chat_history


def test_usage_reports(reset: None) -> None:
    EXPECTED_SESSIONS = 2048
    MESSAGES_PER_SESSION = 4
    EXPECTED_MESSAGES = EXPECTED_SESSIONS * MESSAGES_PER_SESSION

    seed_chat_history(EXPECTED_SESSIONS, MESSAGES_PER_SESSION, 90)

    with get_session_with_current_tenant() as db_session:
        # count of all entries should be exact
        period = (
            datetime.fromtimestamp(0, tz=timezone.utc),
            datetime.now(tz=timezone.utc),
        )

        count = 0
        for entry_batch in get_all_empty_chat_message_entries(db_session, period):
            for entry in entry_batch:
                count += 1

        assert count == EXPECTED_MESSAGES

        # count in a one month time range should be within a certain range statistically
        # this can be improved if we seed the chat history data deterministically
        period = (
            datetime.now(tz=timezone.utc) - timedelta(days=30),
            datetime.now(tz=timezone.utc),
        )

        count = 0
        for entry_batch in get_all_empty_chat_message_entries(db_session, period):
            for entry in entry_batch:
                count += 1

        lower = EXPECTED_MESSAGES // 3 - (EXPECTED_MESSAGES // (3 * 3))
        upper = EXPECTED_MESSAGES // 3 + (EXPECTED_MESSAGES // (3 * 3))
        assert count > lower
        assert count < upper
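For reference, the band asserted above works out as follows: timestamps are spread roughly uniformly over 90 days, so a 30-day window should catch about a third of the messages, give or take a third of that expectation. The arithmetic as a sketch:

    EXPECTED_MESSAGES = 2048 * 4                 # 8192 seeded messages
    expected_in_window = EXPECTED_MESSAGES // 3  # ~2730 expected in 30 of 90 days
    tolerance = EXPECTED_MESSAGES // 9           # ~910 either way
    assert (expected_in_window - tolerance, expected_in_window + tolerance) == (1820, 3640)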
@@ -80,3 +80,13 @@ prod cluster**
 - `kubectl delete -f .`
 - To not delete the persistent volumes (Document indexes and Users), specify the specific `.yaml` files instead of
   `.` without specifying delete on persistent-volumes.yaml.

+### Using Helm to deploy to an existing cluster
+
+Onyx provides a Helm chart that conveniently installs all services onto an existing Kubernetes cluster. To install:
+
+* The Helm chart is not yet published, so clone this repo to install it.
+* Configure access to the cluster via kubectl, and ensure the kubectl context is set to the cluster you want to use.
+* The default secrets, environment variables, and other service-level configuration are stored in `deployment/helm/charts/onyx/values.yaml`. You may create a separate `override.yaml` with your changes.
+* `cd deployment/helm/charts/onyx` and run `helm install onyx -n onyx -f override.yaml .`. This installs Onyx on the cluster under the `onyx` namespace.
+* Check the status of the deploy using `kubectl get pods -n onyx`.
27  deployment/helm/charts/onyx/templates/ingress-api.yaml  Normal file
@@ -0,0 +1,27 @@
{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ include "onyx-stack.fullname" . }}-ingress-api
  annotations:
    kubernetes.io/ingress.class: nginx
    nginx.ingress.kubernetes.io/rewrite-target: /$2
    nginx.ingress.kubernetes.io/use-regex: "true"
    cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
spec:
  rules:
    - host: {{ .Values.ingress.api.host }}
      http:
        paths:
          - path: /api(/|$)(.*)
            pathType: Prefix
            backend:
              service:
                name: {{ include "onyx-stack.fullname" . }}-api-service
                port:
                  number: {{ .Values.api.service.servicePort }}
  tls:
    - hosts:
        - {{ .Values.ingress.api.host }}
      secretName: {{ include "onyx-stack.fullname" . }}-ingress-api-tls
{{- end }}
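The combination of `use-regex`, the `/api(/|$)(.*)` path, and `rewrite-target: /$2` strips the `/api` prefix before requests reach the API service. A quick Python check of what the capture groups do (an illustration only; nginx applies the same regex semantics):

    import re

    # Group 2 captures everything after the /api prefix.
    m = re.match(r"/api(/|$)(.*)", "/api/chat/send-message")
    assert m is not None
    assert "/" + m.group(2) == "/chat/send-message"  # what rewrite-target /$2 produces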
26  deployment/helm/charts/onyx/templates/ingress-webserver.yaml  Normal file
@@ -0,0 +1,26 @@
{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: {{ include "onyx-stack.fullname" . }}-ingress-webserver
  annotations:
    kubernetes.io/ingress.class: nginx
    cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
    kubernetes.io/tls-acme: "true"
spec:
  rules:
    - host: {{ .Values.ingress.webserver.host }}
      http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: {{ include "onyx-stack.fullname" . }}-webserver
                port:
                  number: {{ .Values.webserver.service.servicePort }}
  tls:
    - hosts:
        - {{ .Values.ingress.webserver.host }}
      secretName: {{ include "onyx-stack.fullname" . }}-ingress-webserver-tls
{{- end }}
20  deployment/helm/charts/onyx/templates/lets-encrypt.yaml  Normal file
@@ -0,0 +1,20 @@
{{- if .Values.letsencrypt.enabled -}}
apiVersion: cert-manager.io/v1
kind: ClusterIssuer
metadata:
  name: {{ include "onyx-stack.fullname" . }}-letsencrypt
spec:
  acme:
    # The ACME server URL
    server: https://acme-v02.api.letsencrypt.org/directory
    # Email address used for ACME registration
    email: {{ .Values.letsencrypt.email }}
    # Name of a secret used to store the ACME account private key
    privateKeySecretRef:
      name: {{ include "onyx-stack.fullname" . }}-letsencrypt
    # Enable the HTTP-01 challenge provider
    solvers:
      - http01:
          ingress:
            class: nginx
{{- end }}
@@ -376,22 +376,17 @@ redis:
   existingSecret: onyx-secrets
   existingSecretPasswordKey: redis_password

-# ingress:
-#   enabled: false
-#   className: ""
-#   annotations: {}
-#     # kubernetes.io/ingress.class: nginx
-#     # kubernetes.io/tls-acme: "true"
-#   hosts:
-#     - host: chart-example.local
-#       paths:
-#         - path: /
-#           pathType: ImplementationSpecific
-#   tls: []
-#   # - secretName: chart-example-tls
-#   #   hosts:
-#   #     - chart-example.local
+ingress:
+  enabled: false
+  className: ""
+  api:
+    host: onyx.local
+  webserver:
+    host: onyx.local
+
+letsencrypt:
+  enabled: false
+  email: "abc@abc.com"

 auth:
   # existingSecret onyx-secret for storing smtp, oauth, slack, and other secrets
@@ -26,7 +26,7 @@ export function Checkbox({
   onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
 }) {
   return (
-    <label className="flex text-sm cursor-pointer">
+    <label className="flex text-xs cursor-pointer">
       <input
         checked={checked}
         onChange={onChange}

@@ -34,7 +34,7 @@ export function Checkbox({
         className="mr-2 w-3.5 h-3.5 my-auto"
       />
       <div>
-        <Label>{label}</Label>
+        <Label small>{label}</Label>
         {sublabel && <SubLabel>{sublabel}</SubLabel>}
       </div>
     </label>
@@ -208,7 +208,7 @@ export function SettingsForm() {
   }

   return (
-    <div>
+    <div className="flex flex-col pb-8">
       {popup}
       <Title className="mb-4">Workspace Settings</Title>
       <Checkbox
@@ -290,23 +290,71 @@ export function SettingsForm() {
             id="chatRetentionInput"
             placeholder="Infinite Retention"
           />
-          <Button
-            onClick={handleSetChatRetention}
-            variant="submit"
-            size="sm"
-            className="mr-3"
-          >
-            Set Retention Limit
-          </Button>
-          <Button
-            onClick={handleClearChatRetention}
-            variant="default"
-            size="sm"
-          >
-            Retain All
-          </Button>
+          <div className="mr-auto flex gap-2">
+            <Button
+              onClick={handleSetChatRetention}
+              variant="submit"
+              size="sm"
+              className="mr-auto"
+            >
+              Set Retention Limit
+            </Button>
+            <Button
+              onClick={handleClearChatRetention}
+              variant="default"
+              size="sm"
+              className="mr-auto"
+            >
+              Retain All
+            </Button>
+          </div>
         </>
       )}

+      {/* Image Processing Settings */}
+      <Title className="mt-8 mb-4">Image Processing</Title>
+
+      <div className="flex flex-col gap-2">
+        <Checkbox
+          label="Enable Image Extraction and Analysis"
+          sublabel="Extract and analyze images from documents during indexing. This allows the system to process images and create searchable descriptions of them."
+          checked={settings.image_extraction_and_analysis_enabled ?? false}
+          onChange={(e) =>
+            handleToggleSettingsField(
+              "image_extraction_and_analysis_enabled",
+              e.target.checked
+            )
+          }
+        />
+
+        <Checkbox
+          label="Enable Search-time Image Analysis"
+          sublabel="Analyze images at search time when a user asks about images. This provides more detailed and query-specific image analysis but may increase search-time latency."
+          checked={settings.search_time_image_analysis_enabled ?? false}
+          onChange={(e) =>
+            handleToggleSettingsField(
+              "search_time_image_analysis_enabled",
+              e.target.checked
+            )
+          }
+        />
+
+        <IntegerInput
+          label="Maximum Image Size for Analysis (MB)"
+          sublabel="Images larger than this size will not be analyzed to prevent excessive resource usage."
+          value={settings.image_analysis_max_size_mb ?? null}
+          onChange={(e) => {
+            const value = e.target.value ? parseInt(e.target.value) : null;
+            if (value !== null && !isNaN(value) && value > 0) {
+              updateSettingField([
+                { fieldName: "image_analysis_max_size_mb", newValue: value },
+              ]);
+            }
+          }}
+          id="image-analysis-max-size"
+          placeholder="Enter maximum size in MB"
+        />
+      </div>
     </div>
   );
 }
@@ -21,6 +21,11 @@ export interface Settings {
   auto_scroll: boolean;
   temperature_override_enabled: boolean;
   query_history_type: QueryHistoryType;
+
+  // Image processing settings
+  image_extraction_and_analysis_enabled?: boolean;
+  search_time_image_analysis_enabled?: boolean;
+  image_analysis_max_size_mb?: number;
 }

 export enum NotificationType {
@@ -61,6 +61,7 @@ export function EmailPasswordForm({

     if (!response.ok) {
+      setIsWorking(false);

       const errorDetail = (await response.json()).detail;
       let errorMsg = "Unknown error";
       if (typeof errorDetail === "object" && errorDetail.reason) {

@@ -96,12 +97,13 @@ export function EmailPasswordForm({
     } else {
+      setIsWorking(false);
       const errorDetail = (await loginResponse.json()).detail;

       let errorMsg = "Unknown error";
       if (errorDetail === "LOGIN_BAD_CREDENTIALS") {
         errorMsg = "Invalid email or password";
       } else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
         errorMsg = "Create an account to set a password";
       } else if (typeof errorDetail === "string") {
         errorMsg = errorDetail;
       }
       if (loginResponse.status === 429) {
         errorMsg = "Too many requests. Please try again later.";
@@ -95,7 +95,7 @@ const Page = async (props: {
         </div>
       )}

-      <EmailPasswordFormau
+      <EmailPasswordForm
         isSignup
         shouldVerify={authTypeMetadata?.requiresVerification}
         nextUrl={nextUrl}
@@ -191,6 +191,7 @@ export const FolderDropdown = forwardRef<HTMLDivElement, FolderDropdownProps>(
                   onChange={(e) => setNewFolderName(e.target.value)}
                   className="text-sm font-medium bg-transparent outline-none w-full pb-1 border-b border-background-500 transition-colors duration-200"
                   onKeyDown={(e) => {
+                    e.stopPropagation();
                     if (e.key === "Enter") {
                       handleEdit();
                     }
@@ -303,7 +303,6 @@ const FolderItem = ({
               key={chatSession.id}
               chatSession={chatSession}
               isSelected={chatSession.id === currentChatId}
-              skipGradient={isDragOver}
               showShareModal={showShareModal}
               showDeleteModal={showDeleteModal}
             />
@@ -32,21 +32,17 @@ export function ChatSessionDisplay({
   chatSession,
   search,
   isSelected,
-  skipGradient,
   closeSidebar,
   showShareModal,
   showDeleteModal,
   foldersExisting,
   isDragging,
 }: {
   chatSession: ChatSession;
   isSelected: boolean;
   search?: boolean;
-  skipGradient?: boolean;
   closeSidebar?: () => void;
   showShareModal?: (chatSession: ChatSession) => void;
   showDeleteModal?: (chatSession: ChatSession) => void;
   foldersExisting?: boolean;
   isDragging?: boolean;
 }) {
   const router = useRouter();

@@ -238,8 +234,12 @@ export function ChatSessionDisplay({
                     e.preventDefault();
                     e.stopPropagation();
                   }}
-                  onChange={(e) => setChatName(e.target.value)}
+                  onChange={(e) => {
+                    setChatName(e.target.value);
+                  }}
                   onKeyDown={(event) => {
                     event.stopPropagation();

                     if (event.key === "Enter") {
                       onRename();
                       event.preventDefault();
@@ -264,7 +264,6 @@ export function PagesTab({
             >
               <ChatSessionDisplay
                 chatSession={chat}
                 foldersExisting={foldersExisting}
                 isSelected={currentChatId === chat.id}
                 showShareModal={showShareModal}
                 showDeleteModal={showDeleteModal}
@@ -40,8 +40,12 @@ export const ConnectorTitle = ({
     const typedConnector = connector as Connector<GithubConfig>;
     additionalMetadata.set(
       "Repo",
-      typedConnector.connector_specific_config.repo_name
-        ? `${typedConnector.connector_specific_config.repo_owner}/${typedConnector.connector_specific_config.repo_name}`
+      typedConnector.connector_specific_config.repositories
+        ? `${typedConnector.connector_specific_config.repo_owner}/${
+            typedConnector.connector_specific_config.repositories.includes(",")
+              ? "multiple repos"
+              : typedConnector.connector_specific_config.repositories
+          }`
         : `${typedConnector.connector_specific_config.repo_owner}/*`
     );
   } else if (connector.source === "gitlab") {
@@ -190,10 +190,12 @@ export const connectorConfigs: Record<
     fields: [
       {
         type: "text",
-        query: "Enter the repository name:",
-        label: "Repository Name",
-        name: "repo_name",
+        query: "Enter the repository name(s):",
+        label: "Repository Name(s)",
+        name: "repositories",
         optional: false,
+        description:
+          "For multiple repositories, enter comma-separated names (e.g., repo1,repo2,repo3)",
       },
     ],
   },
@@ -1358,7 +1360,7 @@ export interface WebConfig {

 export interface GithubConfig {
   repo_owner: string;
-  repo_name: string;
+  repositories: string; // Comma-separated list of repository names
   include_prs: boolean;
   include_issues: boolean;
 }