mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 00:05:47 +00:00
Compare commits
1 Commits
gmail
...
fix_openap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25d9266da4 |
94
.github/workflows/nightly-scan-licenses.yml
vendored
94
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -53,90 +53,24 @@ jobs:
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: always()
|
||||
if: ${{ always() }}
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@0.29.0
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# with:
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
# sarif_file: trivy-results.sarif
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
"""improved index
|
||||
|
||||
Revision ID: 3bd4c84fe72f
|
||||
Revises: 8f43500ee275
|
||||
Create Date: 2025-02-26 13:07:56.217791
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3bd4c84fe72f"
|
||||
down_revision = "8f43500ee275"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# NOTE:
|
||||
# This migration addresses issues with the previous migration (8f43500ee275) which caused
|
||||
# an outage by creating an index without using CONCURRENTLY. This migration:
|
||||
#
|
||||
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
|
||||
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
|
||||
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
|
||||
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
|
||||
# (see: https://github.com/sqlalchemy/alembic/issues/277)
|
||||
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create a GIN index for full-text search on chat_message.message
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE chat_message
|
||||
ADD COLUMN message_tsv tsvector
|
||||
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
|
||||
"""
|
||||
)
|
||||
|
||||
# Commit the current transaction before creating concurrent indexes
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
|
||||
ON chat_message
|
||||
USING GIN (message_tsv)
|
||||
"""
|
||||
)
|
||||
|
||||
# Also add a stored tsvector column for chat_session.description
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE chat_session
|
||||
ADD COLUMN description_tsv tsvector
|
||||
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
|
||||
"""
|
||||
)
|
||||
|
||||
# Commit again before creating the second concurrent index
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
|
||||
ON chat_session
|
||||
USING GIN (description_tsv)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the indexes first (use CONCURRENTLY for dropping too)
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
|
||||
|
||||
op.execute("COMMIT")
|
||||
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
|
||||
|
||||
# Then drop the columns
|
||||
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
|
||||
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
@@ -18,13 +18,12 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
# Create a basic index on the lowercase message column for direct text matching
|
||||
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
|
||||
# op.execute(
|
||||
# """
|
||||
# CREATE INDEX idx_chat_message_message_lower
|
||||
# ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
# """
|
||||
# )
|
||||
pass
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX idx_chat_message_message_lower
|
||||
ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""force lowercase all users
|
||||
|
||||
Revision ID: f11b408e39d3
|
||||
Revises: 3bd4c84fe72f
|
||||
Create Date: 2025-02-26 17:04:55.683500
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f11b408e39d3"
|
||||
down_revision = "3bd4c84fe72f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1) Convert all existing user emails to lowercase
|
||||
from alembic import op
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
"""
|
||||
)
|
||||
|
||||
# 2) Add a check constraint to ensure emails are always lowercase
|
||||
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the check constraint
|
||||
from alembic import op
|
||||
|
||||
op.drop_constraint("ensure_lowercase_email", "user", type_="check")
|
||||
@@ -1,42 +0,0 @@
|
||||
"""lowercase multi-tenant user auth
|
||||
|
||||
Revision ID: 34e3630c7f32
|
||||
Revises: a4f6ee863c47
|
||||
Create Date: 2025-02-26 15:03:01.211894
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "34e3630c7f32"
|
||||
down_revision = "a4f6ee863c47"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1) Convert all existing rows to lowercase
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE user_tenant_mapping
|
||||
SET email = LOWER(email)
|
||||
"""
|
||||
)
|
||||
# 2) Add a check constraint so that emails cannot be written in uppercase
|
||||
op.create_check_constraint(
|
||||
"ensure_lowercase_email",
|
||||
"user_tenant_mapping",
|
||||
"email = LOWER(email)",
|
||||
schema="public",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the check constraint
|
||||
op.drop_constraint(
|
||||
"ensure_lowercase_email",
|
||||
"user_tenant_mapping",
|
||||
schema="public",
|
||||
type_="check",
|
||||
)
|
||||
@@ -5,9 +5,11 @@ from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.background.task_utils import build_celery_task_wrapper
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.db.chat import delete_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -16,8 +18,10 @@ logger = setup_logger()
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
def perform_ttl_management_task(
|
||||
retention_limit_days: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
|
||||
|
||||
@@ -31,19 +35,24 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
token = None
|
||||
if MULTI_TENANT and tenant_id is not None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
@@ -51,9 +60,9 @@ def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
|
||||
@@ -18,7 +18,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
"""This function is likely to move in the worker refactor happening next."""
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -2,7 +2,6 @@ import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -22,10 +21,8 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
@@ -38,8 +35,6 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history(
|
||||
db_session: Session,
|
||||
@@ -112,17 +107,6 @@ def get_user_chat_sessions(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
# this is a direct query on the user id
|
||||
if ONYX_QUERY_HISTORY_TYPE in [
|
||||
QueryHistoryType.DISABLED,
|
||||
QueryHistoryType.ANONYMIZED,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Per user query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session, limit=0
|
||||
@@ -138,7 +122,6 @@ def get_user_chat_sessions(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
@@ -158,12 +141,6 @@ def get_chat_session_history(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
page_of_chat_sessions = get_page_of_chat_sessions(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
@@ -180,16 +157,11 @@ def get_chat_session_history(
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
minimal_chat_sessions: list[ChatSessionMinimal] = []
|
||||
|
||||
for chat_session in page_of_chat_sessions:
|
||||
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
minimal_chat_sessions.append(minimal_chat_session)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=minimal_chat_sessions,
|
||||
items=[
|
||||
ChatSessionMinimal.from_chat_session(chat_session)
|
||||
for chat_session in page_of_chat_sessions
|
||||
],
|
||||
total_items=total_filtered_chat_sessions_count,
|
||||
)
|
||||
|
||||
@@ -200,12 +172,6 @@ def get_chat_session_admin(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -227,9 +193,6 @@ def get_chat_session_admin(
|
||||
f"Could not create snapshot for chat session with id '{chat_session_id}'",
|
||||
)
|
||||
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
return snapshot
|
||||
|
||||
|
||||
@@ -240,12 +203,6 @@ def get_query_history_as_csv(
|
||||
end: datetime | None = None,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
|
||||
@@ -256,9 +213,6 @@ def get_query_history_as_csv(
|
||||
|
||||
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
|
||||
for chat_session_snapshot in complete_chat_session_history:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
question_answer_pairs.extend(
|
||||
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -42,9 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
return response.json()
|
||||
|
||||
|
||||
def fetch_billing_information(
|
||||
tenant_id: str,
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@@ -55,19 +52,8 @@ def fetch_billing_information(
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Check if the response indicates no subscription
|
||||
if (
|
||||
isinstance(response_data, dict)
|
||||
and "subscribed" in response_data
|
||||
and not response_data["subscribed"]
|
||||
):
|
||||
return SubscriptionStatusResponse(**response_data)
|
||||
|
||||
# Otherwise, parse as BillingInformation
|
||||
return BillingInformation(**response_data)
|
||||
billing_info = BillingInformation(**response.json())
|
||||
return billing_info
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
|
||||
@@ -200,35 +200,14 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Anthropic provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4o",
|
||||
default_model_name="gpt-4",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
@@ -240,6 +219,25 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Anthropic provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
|
||||
if COHERE_DEFAULT_API_KEY:
|
||||
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
|
||||
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
|
||||
@@ -10,7 +10,6 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
_API_KEY_HEADER_NAME = "Authorization"
|
||||
@@ -36,7 +35,8 @@ class ApiKeyDescriptor(BaseModel):
|
||||
|
||||
|
||||
def generate_api_key(tenant_id: str | None = None) -> str:
|
||||
if not MULTI_TENANT or not tenant_id:
|
||||
# For backwards compatibility, if no tenant_id, generate old style key
|
||||
if not tenant_id:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
|
||||
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
|
||||
|
||||
@@ -2,8 +2,6 @@ import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
from email.utils import make_msgid
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
@@ -15,7 +13,6 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
@@ -153,9 +150,8 @@ def send_email(
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
msg["From"] = mail_from
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
@@ -177,7 +173,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We're sorry to see you go.</p>"
|
||||
"<p>We’re sorry to see you go.</p>"
|
||||
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
|
||||
"<p>If you change your mind, you can always come back!</p>"
|
||||
)
|
||||
@@ -243,13 +239,13 @@ def send_user_email_invite(
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
tenant_id: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if MULTI_TENANT:
|
||||
if tenant_id:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
|
||||
@@ -214,7 +214,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
@@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
user: User | None = None
|
||||
user: User
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
@@ -420,20 +420,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.user_db.get_by_email(account_email)
|
||||
user = await self.get_by_email(account_email)
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
# Make sure user is not None before adding OAuth account
|
||||
if user is not None:
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
else:
|
||||
# This shouldn't happen since get_by_email would raise UserNotExists
|
||||
# but adding as a safeguard
|
||||
raise exceptions.UserNotExists()
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
|
||||
# If user not found by OAuth account or email, create a new user
|
||||
except exceptions.UserNotExists:
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
@@ -444,36 +439,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
|
||||
# Add OAuth account only if user creation was successful
|
||||
if user is not None:
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to create user account"
|
||||
)
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
# User exists, update OAuth account if needed
|
||||
if user is not None: # Add explicit check
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
|
||||
# Ensure user is not None before proceeding
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to authenticate or create user"
|
||||
)
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
@@ -568,7 +553,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async_return_default_schema,
|
||||
)(email=user.email)
|
||||
|
||||
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
|
||||
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
|
||||
@@ -2,7 +2,6 @@ import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
@@ -132,9 +131,9 @@ def on_task_postrun(
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
tenant_id = None
|
||||
else:
|
||||
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
|
||||
@@ -34,7 +34,7 @@ def _get_deletion_status(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None = None,
|
||||
) -> TaskQueueState | None:
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None = None,
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id, credential_id, db_session, tenant_id
|
||||
|
||||
@@ -109,7 +109,9 @@ def revoke_tasks_blocking_deletion(
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
@@ -222,7 +224,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
@@ -343,7 +345,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -498,7 +500,7 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
|
||||
def validate_connector_deletion_fences(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -538,7 +540,7 @@ def validate_connector_deletion_fences(
|
||||
|
||||
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
|
||||
@@ -221,7 +221,7 @@ def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
"""Returns a randomized payload id on success.
|
||||
Returns None if no syncing is required."""
|
||||
@@ -320,7 +320,7 @@ def try_creating_permissions_sync_task(
|
||||
def connector_permission_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
@@ -410,6 +410,7 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
@@ -509,7 +510,7 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
def update_external_document_permissions_task(
|
||||
self: Task,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
@@ -584,7 +585,7 @@ def update_external_document_permissions_task(
|
||||
|
||||
|
||||
def validate_permission_sync_fences(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -631,7 +632,7 @@ def validate_permission_sync_fences(
|
||||
|
||||
|
||||
def validate_permission_sync_fence(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
reserved_tasks: set[str],
|
||||
@@ -841,7 +842,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
|
||||
@@ -123,7 +123,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
@@ -220,7 +220,7 @@ def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
@@ -306,7 +306,7 @@ def try_creating_external_group_sync_task(
|
||||
def connector_external_group_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
External group sync task for a given connector credential pair
|
||||
@@ -392,6 +392,7 @@ def connector_external_group_sync_generator_task(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
@@ -493,7 +494,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
@@ -525,7 +526,7 @@ def validate_external_group_sync_fences(
|
||||
|
||||
|
||||
def validate_external_group_sync_fence(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
|
||||
@@ -182,7 +182,7 @@ class SimpleJobResult:
|
||||
|
||||
|
||||
class ConnectorIndexingContext(BaseModel):
|
||||
tenant_id: str
|
||||
tenant_id: str | None
|
||||
cc_pair_id: int
|
||||
search_settings_id: int
|
||||
index_attempt_id: int
|
||||
@@ -210,7 +210,7 @@ class ConnectorIndexingLogBuilder:
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
@@ -358,7 +358,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
@@ -598,7 +598,7 @@ def connector_indexing_task(
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -890,7 +890,7 @@ def connector_indexing_proxy_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
@@ -1170,7 +1170,7 @@ def connector_indexing_proxy_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
soft_time_limit=300,
|
||||
)
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
|
||||
"""Clean up old checkpoints that are older than 7 days."""
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase):
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
@@ -311,7 +311,7 @@ def validate_indexing_fence(
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
@@ -442,7 +442,7 @@ def try_creating_indexing_task(
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
if not LLM_MODEL_UPDATE_API_URL:
|
||||
raise ValueError("LLM model update API URL not configured")
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class Metric(BaseModel):
|
||||
}
|
||||
task_logger.info(json.dumps(data))
|
||||
|
||||
def emit(self, tenant_id: str) -> None:
|
||||
def emit(self, tenant_id: str | None) -> None:
|
||||
# Convert value to appropriate type based on the input value
|
||||
bool_value = None
|
||||
float_value = None
|
||||
@@ -656,7 +656,7 @@ def build_job_id(
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
"""Collect and emit metrics about background processes.
|
||||
This task runs periodically to gather metrics about:
|
||||
- Queue lengths for different Celery queues
|
||||
@@ -864,7 +864,7 @@ def cloud_monitor_celery_queues(
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
|
||||
return monitor_celery_queues_helper(self)
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
|
||||
@@ -114,7 +114,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
@@ -211,7 +211,7 @@ def try_creating_prune_generator_task(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
"""Checks for any conditions that should block the pruning generator task from being
|
||||
created, then creates the task.
|
||||
@@ -333,7 +333,7 @@ def connector_pruning_generator_task(
|
||||
cc_pair_id: int,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -521,7 +521,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -567,7 +567,7 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
|
||||
def validate_pruning_fences(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -615,7 +615,7 @@ def validate_pruning_fences(
|
||||
|
||||
|
||||
def validate_pruning_fence(
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
queued_tasks: set[str],
|
||||
|
||||
@@ -32,7 +32,7 @@ class RetryDocumentIndex:
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
return self.index.delete_single(
|
||||
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
|
||||
@@ -76,7 +76,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
document_id: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> bool:
|
||||
"""A lightweight subtask used to clean up document to cc pair relationships.
|
||||
Created by connection deletion and connector pruning parent tasks."""
|
||||
@@ -297,8 +297,7 @@ def cloud_beat_task_generator(
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
tenant_ids: list[str] | list[None] = []
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
@@ -326,8 +325,6 @@ def cloud_beat_task_generator(
|
||||
expires=expires,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
num_processed_tenants += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -347,7 +344,6 @@ def cloud_beat_task_generator(
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
@@ -208,7 +208,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
|
||||
@@ -284,7 +284,7 @@ def try_generate_document_set_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -361,7 +361,7 @@ def try_generate_user_group_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -448,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
@@ -523,7 +523,9 @@ def monitor_document_set_taskset(
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, *, tenant_id: str | None
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Optional
|
||||
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.setup import setup_logger
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -55,7 +55,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -68,6 +67,7 @@ def _get_connector_runner(
|
||||
batch_size: int,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None,
|
||||
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
@@ -86,6 +86,7 @@ def _get_connector_runner(
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# validate the connector settings
|
||||
@@ -240,7 +241,7 @@ def _check_failure_threshold(
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -387,6 +388,7 @@ def _run_indexing(
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
@@ -679,7 +681,7 @@ def _run_indexing(
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
@@ -699,7 +701,7 @@ def run_indexing_entrypoint(
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if MULTI_TENANT:
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import cast
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
#####
|
||||
@@ -30,9 +29,6 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
|
||||
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
|
||||
)
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
|
||||
@@ -213,12 +213,6 @@ class AuthType(str, Enum):
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class QueryHistoryType(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
ANONYMIZED = "anonymized"
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
# Special characters for password validation
|
||||
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
@@ -163,9 +164,13 @@ def instantiate_connector(
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
tenant_id: str | None = None,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
if source in DocumentSourceRequiringTenantContext:
|
||||
connector_specific_config["tenant_id"] = tenant_id
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
@@ -179,6 +184,7 @@ def validate_ccpair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None,
|
||||
enforce_creation: bool = True,
|
||||
) -> bool:
|
||||
if INTEGRATION_TESTS_MODE:
|
||||
@@ -210,6 +216,7 @@ def validate_ccpair_for_user(
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
credential=credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -27,6 +27,8 @@ from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -163,10 +165,12 @@ class LocalFileConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [Path(file_location) for file_location in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.tenant_id = tenant_id
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -175,8 +179,9 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
@@ -198,6 +203,8 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import openpyxl # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
@@ -45,15 +43,12 @@ def _extract_sections_basic(
|
||||
) -> list[Section]:
|
||||
mime_type = file["mimeType"]
|
||||
link = file["webViewLink"]
|
||||
supported_file_types = set(item.value for item in GDriveMimeType)
|
||||
|
||||
if mime_type not in supported_file_types:
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
try:
|
||||
# ---------------------------
|
||||
# Google Sheets extraction
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
try:
|
||||
sheets_service = build(
|
||||
@@ -114,53 +109,7 @@ def _extract_sections_basic(
|
||||
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
)
|
||||
# ---------------------------
|
||||
# Microsoft Excel (.xlsx or .xls) extraction branch
|
||||
elif mime_type in [
|
||||
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
|
||||
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
|
||||
]:
|
||||
try:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
|
||||
tmp.write(response)
|
||||
tmp_path = tmp.name
|
||||
|
||||
section_separator = "\n\n"
|
||||
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
|
||||
|
||||
# Work similarly to the xlsx_to_text function used for file connector
|
||||
# but returns Sections instead of a string
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
text=(
|
||||
f"Sheet: {sheet.title}\n\n"
|
||||
+ section_separator.join(
|
||||
",".join(map(str, row))
|
||||
for row in sheet.iter_rows(
|
||||
min_row=1, values_only=True
|
||||
)
|
||||
if row
|
||||
)
|
||||
),
|
||||
)
|
||||
for sheet in workbook.worksheets
|
||||
]
|
||||
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error extracting data from Excel file '{file['name']}': {e}"
|
||||
)
|
||||
return [
|
||||
Section(link=link, text="Error extracting data from Excel file")
|
||||
]
|
||||
|
||||
# ---------------------------
|
||||
# Export for Google Docs, PPT, and fallback for spreadsheets
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
@@ -179,8 +128,6 @@ def _extract_sections_basic(
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
# ---------------------------
|
||||
# Plain text and Markdown files
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
@@ -194,8 +141,6 @@ def _extract_sections_basic(
|
||||
.decode("utf-8"),
|
||||
)
|
||||
]
|
||||
# ---------------------------
|
||||
# Word, PowerPoint, PDF files
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
@@ -225,11 +170,7 @@ def _extract_sections_basic(
|
||||
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
|
||||
# Catch-all case, should not happen since there should be specific handling
|
||||
# for each of the supported file types
|
||||
error_message = f"Unsupported file type: {mime_type}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
except Exception:
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
@@ -5,10 +5,6 @@ from typing import Any
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
SPREADSHEET_OPEN_FORMAT = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
@@ -72,7 +73,7 @@ def insert_api_key(
|
||||
# Get tenant_id from context var (will be default schema for single tenant)
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
api_key = generate_api_key(tenant_id)
|
||||
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
|
||||
api_key_user_id = uuid.uuid4()
|
||||
|
||||
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
|
||||
@@ -168,7 +168,7 @@ def get_chat_sessions_by_user(
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_updated))
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
@@ -962,7 +962,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
chat_message.sub_questions
|
||||
),
|
||||
refined_answer_improvement=chat_message.refined_answer_improvement,
|
||||
is_agentic=chat_message.is_agentic,
|
||||
error=chat_message.error,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,13 +3,14 @@ from typing import Optional
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import literal
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import union_all
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -25,87 +26,127 @@ def search_chat_sessions(
|
||||
include_onyxbot_flows: bool = False,
|
||||
) -> Tuple[List[ChatSession], bool]:
|
||||
"""
|
||||
Fast full-text search on ChatSession + ChatMessage using tsvectors.
|
||||
Search for chat sessions based on the provided query.
|
||||
If no query is provided, returns recent chat sessions.
|
||||
|
||||
If no query is provided, returns the most recent chat sessions.
|
||||
Otherwise, searches both chat messages and session descriptions.
|
||||
|
||||
Returns a tuple of (sessions, has_more) where has_more indicates if
|
||||
there are additional results beyond the requested page.
|
||||
Returns a tuple of (chat_sessions, has_more)
|
||||
"""
|
||||
offset_val = (page - 1) * page_size
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# If no query, just return the most recent sessions
|
||||
# If no search query, we use standard SQLAlchemy pagination
|
||||
if not query or not query.strip():
|
||||
stmt = (
|
||||
select(ChatSession)
|
||||
.order_by(desc(ChatSession.time_created))
|
||||
.offset(offset_val)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = select(ChatSession)
|
||||
if user_id:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset(offset).limit(page_size + 1)
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
sessions = result.scalars().all()
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
has_more = len(sessions) > page_size
|
||||
has_more = len(chat_sessions) > page_size
|
||||
if has_more:
|
||||
sessions = sessions[:page_size]
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return list(sessions), has_more
|
||||
return list(chat_sessions), has_more
|
||||
|
||||
# Otherwise, proceed with full-text search
|
||||
query = query.strip()
|
||||
words = query.lower().strip().split()
|
||||
|
||||
base_conditions = []
|
||||
if user_id is not None:
|
||||
base_conditions.append(ChatSession.user_id == user_id)
|
||||
# Message mach subquery
|
||||
message_matches = []
|
||||
for word in words:
|
||||
word_like = f"%{word}%"
|
||||
message_match: Select = (
|
||||
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
|
||||
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
|
||||
.where(func.lower(ChatMessage.message).like(word_like))
|
||||
)
|
||||
|
||||
if user_id:
|
||||
message_match = message_match.where(ChatSession.user_id == user_id)
|
||||
|
||||
message_matches.append(message_match)
|
||||
|
||||
if message_matches:
|
||||
message_matches_query = union_all(*message_matches).alias("message_matches")
|
||||
else:
|
||||
return [], False
|
||||
|
||||
# Description matches
|
||||
description_match: Select = select(
|
||||
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
|
||||
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
|
||||
|
||||
if user_id:
|
||||
description_match = description_match.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
|
||||
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
base_conditions.append(ChatSession.deleted.is_(False))
|
||||
description_match = description_match.where(ChatSession.deleted.is_(False))
|
||||
|
||||
message_tsv: ColumnClause = column("message_tsv")
|
||||
description_tsv: ColumnClause = column("description_tsv")
|
||||
# Combine all match sources
|
||||
combined_matches = union_all(
|
||||
message_matches_query.select(), description_match
|
||||
).alias("combined_matches")
|
||||
|
||||
ts_query = func.plainto_tsquery("english", query)
|
||||
|
||||
description_session_ids = (
|
||||
select(ChatSession.id)
|
||||
.where(*base_conditions)
|
||||
.where(description_tsv.op("@@")(ts_query))
|
||||
# Use CTE to group and get max rank
|
||||
session_ranks = (
|
||||
select(
|
||||
combined_matches.c.chat_session_id,
|
||||
func.max(combined_matches.c.search_rank).label("rank"),
|
||||
)
|
||||
.group_by(combined_matches.c.chat_session_id)
|
||||
.alias("session_ranks")
|
||||
)
|
||||
|
||||
message_session_ids = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(*base_conditions)
|
||||
.where(message_tsv.op("@@")(ts_query))
|
||||
)
|
||||
|
||||
combined_ids = description_session_ids.union(message_session_ids).alias(
|
||||
"combined_ids"
|
||||
)
|
||||
|
||||
final_stmt = (
|
||||
select(ChatSession)
|
||||
.join(combined_ids, ChatSession.id == combined_ids.c.id)
|
||||
.order_by(desc(ChatSession.time_created))
|
||||
.distinct()
|
||||
.offset(offset_val)
|
||||
# Get ranked sessions with pagination
|
||||
ranked_query = (
|
||||
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
|
||||
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
|
||||
.offset(offset)
|
||||
.limit(page_size + 1)
|
||||
.options(joinedload(ChatSession.persona))
|
||||
)
|
||||
|
||||
session_objs = db_session.execute(final_stmt).scalars().all()
|
||||
result = ranked_query.all()
|
||||
|
||||
has_more = len(session_objs) > page_size
|
||||
# Extract session IDs and ranks
|
||||
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
|
||||
session_ids = list(session_ids_with_ranks.keys())
|
||||
|
||||
if not session_ids:
|
||||
return [], False
|
||||
|
||||
# Now, let's query the actual ChatSession objects using the IDs
|
||||
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
# Full objects with eager loading
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
# Sort based on above ranking
|
||||
chat_sessions = sorted(
|
||||
chat_sessions,
|
||||
key=lambda session: (
|
||||
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
|
||||
session.time_created.timestamp() * -1, # Then by time (newest first)
|
||||
),
|
||||
)
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
if has_more:
|
||||
session_objs = session_objs[:page_size]
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return list(session_objs), has_more
|
||||
return chat_sessions, has_more
|
||||
|
||||
@@ -360,13 +360,18 @@ def backend_update_credential_json(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _delete_credential_internal(
|
||||
credential: Credential,
|
||||
def delete_credential(
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Internal utility function to handle the actual deletion of a credential"""
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
associated_connectors = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.credential_id == credential_id)
|
||||
@@ -411,35 +416,6 @@ def _delete_credential_internal(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_credential_for_user(
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Delete a credential that belongs to a specific user"""
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
_delete_credential_internal(credential, credential_id, db_session, force)
|
||||
|
||||
|
||||
def delete_credential(
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Delete a credential regardless of ownership (admin function)"""
|
||||
credential = fetch_credential_by_id(credential_id, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(f"Credential by provided id {credential_id} does not exist")
|
||||
|
||||
_delete_credential_internal(credential, credential_id, db_session, force)
|
||||
|
||||
|
||||
def create_initial_public_credential(db_session: Session) -> None:
|
||||
error_msg = (
|
||||
"DB is not in a valid initial state."
|
||||
|
||||
@@ -258,11 +258,11 @@ class SqlEngine:
|
||||
cls._engine = None
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str]:
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return [POSTGRES_DEFAULT_SCHEMA]
|
||||
return [None]
|
||||
|
||||
with get_session_with_shared_schema() as session:
|
||||
result = session.execute(
|
||||
@@ -417,7 +417,7 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import validates
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
@@ -26,7 +25,6 @@ from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
|
||||
from sqlalchemy import Sequence
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
@@ -207,10 +205,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
return value.lower() if value else value
|
||||
|
||||
@property
|
||||
def password_configured(self) -> bool:
|
||||
"""
|
||||
@@ -2275,10 +2269,6 @@ class UserTenantMapping(Base):
|
||||
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
|
||||
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
return value.lower() if value else value
|
||||
|
||||
|
||||
# This is a mapping from tenant IDs to anonymous user paths
|
||||
class TenantAnonymousUserPath(Base):
|
||||
|
||||
@@ -100,14 +100,9 @@ def _add_user_filters(
|
||||
.correlate(Persona)
|
||||
)
|
||||
else:
|
||||
# Group the public persona conditions
|
||||
public_condition = (Persona.is_public == True) & ( # noqa: E712
|
||||
Persona.is_visible == True # noqa: E712
|
||||
)
|
||||
|
||||
where_clause |= public_condition
|
||||
where_clause |= Persona.is_public == True # noqa: E712
|
||||
where_clause &= Persona.is_visible == True # noqa: E712
|
||||
where_clause |= Persona__User.user_id == user.id
|
||||
|
||||
where_clause |= Persona.user_id == user.id
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
# Vespa's Document API.
|
||||
def get_document_chunk_ids(
|
||||
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
large_chunks_enabled: bool,
|
||||
) -> list[UUID]:
|
||||
doc_chunk_ids = []
|
||||
@@ -139,7 +139,7 @@ def get_uuid_from_chunk_info(
|
||||
*,
|
||||
document_id: str,
|
||||
chunk_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
large_chunk_id: int | None = None,
|
||||
) -> UUID:
|
||||
"""NOTE: be VERY carefuly about changing this function. If changed without a migration,
|
||||
@@ -154,7 +154,7 @@ def get_uuid_from_chunk_info(
|
||||
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
|
||||
)
|
||||
unique_identifier_string = "_".join([doc_str, chunk_index])
|
||||
if MULTI_TENANT:
|
||||
if tenant_id and MULTI_TENANT:
|
||||
unique_identifier_string += "_" + tenant_id
|
||||
|
||||
uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
@@ -43,7 +43,7 @@ class IndexBatchParams:
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
tenant_id: str
|
||||
tenant_id: str | None
|
||||
large_chunks_enabled: bool
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class Deletable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ class Updatable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
@@ -270,7 +270,9 @@ class Updatable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Updates some set of chunks. The document and fields to update are specified in the update
|
||||
requests. Each update request in the list applies its changes to a list of document ids.
|
||||
|
||||
@@ -468,7 +468,9 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
@@ -616,7 +618,7 @@ class VespaIndex(DocumentIndex):
|
||||
doc_id: str,
|
||||
*,
|
||||
chunk_count: int | None,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
@@ -659,7 +661,7 @@ class VespaIndex(DocumentIndex):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
total_chunks_deleted = 0
|
||||
|
||||
@@ -158,8 +158,8 @@ def index_doc_batch_with_handler(
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
@@ -317,8 +317,8 @@ def index_doc_batch(
|
||||
document_index: DocumentIndex,
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
|
||||
) -> IndexingPipelineResult:
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
@@ -525,9 +525,9 @@ def build_indexing_pipeline(
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
|
||||
@@ -84,7 +84,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
negative -> ranked lower.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tenant_id: str | None = None
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
boost: int
|
||||
@@ -96,7 +96,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
boost: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
return cls(
|
||||
|
||||
@@ -219,7 +219,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
# If we are multi-tenant, we need to only set up initial public tables
|
||||
with Session(engine) as db_session:
|
||||
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
|
||||
setup_onyx(db_session, None)
|
||||
else:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.chat import get_chat_session_by_message_id
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import ChannelConfig
|
||||
from onyx.onyxbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
@@ -410,11 +410,12 @@ def _build_qa_response_blocks(
|
||||
|
||||
|
||||
def _build_continue_in_web_ui_block(
|
||||
tenant_id: str | None,
|
||||
message_id: int | None,
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
raise ValueError("No message id provided to build continue in web ui block")
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
chat_session = get_chat_session_by_message_id(
|
||||
db_session=db_session,
|
||||
message_id=message_id,
|
||||
@@ -481,6 +482,7 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatOnyxBotResponse,
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
@@ -515,6 +517,7 @@ def build_slack_response_blocks(
|
||||
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
|
||||
web_follow_up_block.append(
|
||||
_build_continue_in_web_ui_block(
|
||||
tenant_id=tenant_id,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import create_doc_retrieval_feedback
|
||||
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
|
||||
@@ -114,7 +114,7 @@ def handle_generate_answer_button(
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -136,6 +136,7 @@ def handle_generate_answer_button(
|
||||
slack_channel_config=slack_channel_config,
|
||||
receiver_ids=None,
|
||||
client=client.web_client,
|
||||
tenant_id=client.tenant_id,
|
||||
channel=channel_id,
|
||||
logger=logger,
|
||||
feedback_reminder_id=None,
|
||||
@@ -150,10 +151,11 @@ def handle_slack_feedback(
|
||||
user_id_to_post_confirmation: str,
|
||||
channel_id_to_post_confirmation: str,
|
||||
thread_ts_to_post_confirmation: str,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
|
||||
create_chat_message_feedback(
|
||||
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
|
||||
@@ -244,7 +246,7 @@ def handle_followup_button(
|
||||
|
||||
tag_ids: list[str] = []
|
||||
group_ids: list[str] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
channel_name, is_dm = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel_id
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.onyxbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -109,6 +109,7 @@ def handle_message(
|
||||
slack_channel_config: SlackChannelConfig,
|
||||
client: WebClient,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
) -> bool:
|
||||
"""Potentially respond to the user message depending on filters and if an answer was generated
|
||||
|
||||
@@ -134,7 +135,9 @@ def handle_message(
|
||||
action = "slack_tag_message"
|
||||
elif is_bot_dm:
|
||||
action = "slack_dm_message"
|
||||
slack_usage_report(action=action, sender_id=sender_id, client=client)
|
||||
slack_usage_report(
|
||||
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
@@ -215,7 +218,7 @@ def handle_message(
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if message_info.email:
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
@@ -241,5 +244,6 @@ def handle_message(
|
||||
channel=channel,
|
||||
logger=logger,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
return issue_with_regular_answer
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -71,6 +72,7 @@ def handle_regular_answer(
|
||||
channel: str,
|
||||
logger: OnyxLoggingAdapter,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
@@ -85,7 +87,7 @@ def handle_regular_answer(
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.email:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
@@ -94,7 +96,7 @@ def handle_regular_answer(
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona
|
||||
if not persona:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
@@ -155,7 +157,7 @@ def handle_regular_answer(
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, onyx_user: User | None
|
||||
) -> ChatOnyxBotResponse:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=onyx_user,
|
||||
@@ -195,7 +197,7 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=user_message.message,
|
||||
user=user,
|
||||
@@ -359,6 +361,7 @@ def handle_regular_answer(
|
||||
return True
|
||||
|
||||
all_blocks = build_slack_response_blocks(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
channel_conf=channel_conf,
|
||||
|
||||
@@ -37,7 +37,6 @@ from onyx.context.search.retrieval.search_runner import (
|
||||
download_nltk_data,
|
||||
)
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackBot
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -93,7 +92,6 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -125,13 +123,13 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
class SlackbotHandler:
|
||||
def __init__(self) -> None:
|
||||
logger.info("Initializing SlackbotHandler")
|
||||
self.tenant_ids: Set[str] = set()
|
||||
self.tenant_ids: Set[str | None] = set()
|
||||
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
|
||||
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
|
||||
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
|
||||
|
||||
# Store Redis lock objects here so we can release them properly
|
||||
self.redis_locks: Dict[str, Lock] = {}
|
||||
self.redis_locks: Dict[str | None, Lock] = {}
|
||||
|
||||
self.running = True
|
||||
self.pod_id = self.get_pod_id()
|
||||
@@ -195,7 +193,7 @@ class SlackbotHandler:
|
||||
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
|
||||
|
||||
def _manage_clients_per_tenant(
|
||||
self, db_session: Session, tenant_id: str, bot: SlackBot
|
||||
self, db_session: Session, tenant_id: str | None, bot: SlackBot
|
||||
) -> None:
|
||||
"""
|
||||
- If the tokens are missing or empty, close the socket client and remove them.
|
||||
@@ -349,7 +347,7 @@ class SlackbotHandler:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Attempt to fetch Slack bots
|
||||
try:
|
||||
bots = list(fetch_slack_bots(db_session=db_session))
|
||||
@@ -387,7 +385,7 @@ class SlackbotHandler:
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def _remove_tenant(self, tenant_id: str) -> None:
|
||||
def _remove_tenant(self, tenant_id: str | None) -> None:
|
||||
"""
|
||||
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
|
||||
(Lock release now happens in `acquire_tenants()`, not here.)
|
||||
@@ -417,7 +415,7 @@ class SlackbotHandler:
|
||||
)
|
||||
|
||||
def start_socket_client(
|
||||
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
|
||||
self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
|
||||
) -> None:
|
||||
socket_client: TenantSocketModeClient = _get_socket_client(
|
||||
slack_bot_tokens, tenant_id, slack_bot_id
|
||||
@@ -588,7 +586,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -682,6 +680,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
|
||||
user_id_to_post_confirmation=user_id,
|
||||
channel_id_to_post_confirmation=channel_id,
|
||||
thread_ts_to_post_confirmation=thread_ts,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
|
||||
query_event_id, _, _ = decompose_action_id(feedback_id)
|
||||
@@ -797,9 +796,8 @@ def process_message(
|
||||
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
|
||||
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
logger.debug(
|
||||
f"Received Slack request of type: '{req.type}' for tenant, {tenant_id}"
|
||||
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
|
||||
)
|
||||
|
||||
# Throw out requests that can't or shouldn't be handled
|
||||
@@ -812,39 +810,50 @@ def process_message(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
channel_name=channel_name,
|
||||
)
|
||||
token: Token[str | None] | None = None
|
||||
# Set the current tenant ID at the beginning for all DB calls within this thread
|
||||
if client.tenant_id:
|
||||
logger.info(f"Setting tenant ID to {client.tenant_id}")
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
channel_name=channel_name,
|
||||
)
|
||||
|
||||
follow_up = bool(
|
||||
slack_channel_config.channel_config
|
||||
and slack_channel_config.channel_config.get("follow_up_tags") is not None
|
||||
)
|
||||
follow_up = bool(
|
||||
slack_channel_config.channel_config
|
||||
and slack_channel_config.channel_config.get("follow_up_tags")
|
||||
is not None
|
||||
)
|
||||
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_channel_config=slack_channel_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_channel_config=slack_channel_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender_id,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender_id,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
finally:
|
||||
if token:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
@@ -903,7 +912,7 @@ def create_process_slack_event() -> (
|
||||
|
||||
|
||||
def _get_socket_client(
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int
|
||||
) -> TenantSocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.onyx.app/slack_bot_setup
|
||||
|
||||
@@ -4,8 +4,6 @@ import re
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -32,7 +30,7 @@ from onyx.configs.onyxbot_configs import (
|
||||
)
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_default_llms
|
||||
@@ -45,7 +43,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.text_processing import replace_whitespaces_w_space
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -572,7 +569,9 @@ def read_slack_thread(
|
||||
return thread_messages
|
||||
|
||||
|
||||
def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None:
|
||||
def slack_usage_report(
|
||||
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
|
||||
) -> None:
|
||||
if DISABLE_TELEMETRY:
|
||||
return
|
||||
|
||||
@@ -584,13 +583,14 @@ def slack_usage_report(action: str, sender_id: str | None, client: WebClient) ->
|
||||
logger.warning("Unable to find sender email")
|
||||
|
||||
if sender_email is not None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
onyx_user = get_user_by_email(email=sender_email, db_session=db_session)
|
||||
|
||||
optional_telemetry(
|
||||
record_type=RecordType.USAGE,
|
||||
data={"action": action},
|
||||
user_id=str(onyx_user.id) if onyx_user else "Non-Onyx-Or-No-Auth-User",
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -663,30 +663,9 @@ def get_feedback_visibility() -> FeedbackVisibility:
|
||||
|
||||
|
||||
class TenantSocketModeClient(SocketModeClient):
|
||||
def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any):
|
||||
def __init__(
|
||||
self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._tenant_id = tenant_id
|
||||
self.tenant_id = tenant_id
|
||||
self.slack_bot_id = slack_bot_id
|
||||
|
||||
@contextmanager
|
||||
def _set_tenant_context(self) -> Generator[None, None, None]:
|
||||
token = None
|
||||
try:
|
||||
if self._tenant_id:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self._tenant_id)
|
||||
yield
|
||||
finally:
|
||||
if token:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def enqueue_message(self, message: str) -> None:
|
||||
with self._set_tenant_context():
|
||||
super().enqueue_message(message)
|
||||
|
||||
def process_message(self) -> None:
|
||||
with self._set_tenant_context():
|
||||
super().process_message()
|
||||
|
||||
def run_message_listeners(self, message: dict, raw_message: str) -> None:
|
||||
with self._set_tenant_context():
|
||||
super().run_message_listeners(message, raw_message)
|
||||
|
||||
@@ -16,10 +16,10 @@ class RedisConnector:
|
||||
"""Composes several classes to simplify interacting with a connector and its
|
||||
associated background tasks / associated redis interactions."""
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
"""id: a connector credential pair id"""
|
||||
|
||||
self.tenant_id: str = tenant_id
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id: int = id
|
||||
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
PREFIX = "connectorsync"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
# documents that should be skipped
|
||||
@@ -60,7 +60,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[int, int] | None:
|
||||
"""We can limit the number of tasks generated here, which is useful to prevent
|
||||
one tenant from overwhelming the sync queue.
|
||||
|
||||
@@ -39,8 +39,8 @@ class RedisConnectorDelete:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -52,8 +52,8 @@ class RedisConnectorPermissionSync:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -44,8 +44,8 @@ class RedisConnectorExternalGroupSync:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -52,12 +52,12 @@ class RedisConnectorIndex:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
id: int,
|
||||
search_settings_id: int,
|
||||
redis: redis.Redis,
|
||||
) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.search_settings_id = search_settings_id
|
||||
self.redis = redis
|
||||
|
||||
@@ -52,8 +52,8 @@ class RedisConnectorPrune:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ class RedisConnectorStop:
|
||||
TIMEOUT_PREFIX = f"{PREFIX}_timeout"
|
||||
TIMEOUT_TTL = 300
|
||||
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id: int = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -58,7 +58,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
document set up to date over multiple batches.
|
||||
|
||||
@@ -14,8 +14,8 @@ class RedisObjectHelper(ABC):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: str):
|
||||
self._tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: str):
|
||||
self._tenant_id: str | None = tenant_id
|
||||
self._id: str = id
|
||||
self.redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -87,7 +87,7 @@ class RedisObjectHelper(ABC):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[int, int] | None:
|
||||
"""First element should be the number of actual tasks generated, second should
|
||||
be the number of docs that were candidates to be synced for the cc pair.
|
||||
|
||||
@@ -24,7 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -59,7 +59,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
user group up to date over multiple batches.
|
||||
|
||||
@@ -37,15 +37,13 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_indexable_chunks(
|
||||
preprocessed_docs: list[dict],
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[list[Document], list[DocMetadataAwareIndexChunk]]:
|
||||
ids_to_documents = {}
|
||||
chunks = []
|
||||
@@ -88,7 +86,7 @@ def _create_indexable_chunks(
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=preprocessed_doc["title_embedding"],
|
||||
tenant_id=tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA,
|
||||
tenant_id=tenant_id,
|
||||
access=default_public_access,
|
||||
document_sets=set(),
|
||||
boost=DEFAULT_BOOST,
|
||||
@@ -113,7 +111,7 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
|
||||
|
||||
def seed_initial_documents(
|
||||
db_session: Session, tenant_id: str, cohere_enabled: bool = False
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Seed initial documents so users don't have an empty index to start
|
||||
|
||||
@@ -620,7 +620,7 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
try:
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session)
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session, tenant_id)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
@@ -646,6 +646,7 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except ValidationError as e:
|
||||
# If validation fails, delete the connector and commit the changes
|
||||
# Ensures we don't leave invalid connectors in the database
|
||||
@@ -659,14 +660,10 @@ def associate_credential_to_connector(
|
||||
)
|
||||
except IntegrityError as e:
|
||||
logger.error(f"IntegrityError: {e}")
|
||||
delete_connector(db_session, connector_id)
|
||||
db_session.commit()
|
||||
|
||||
raise HTTPException(status_code=400, detail="Name must be unique")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error: {e}")
|
||||
|
||||
raise HTTPException(status_code=500, detail="Unexpected error")
|
||||
|
||||
|
||||
|
||||
@@ -902,6 +902,7 @@ def create_connector_with_mock_credential(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -13,12 +13,12 @@ from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
|
||||
from onyx.db.credentials import delete_credential
|
||||
from onyx.db.credentials import delete_credential_for_user
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import fetch_credentials_by_source_for_user
|
||||
from onyx.db.credentials import fetch_credentials_for_user
|
||||
from onyx.db.credentials import swap_credentials_connector
|
||||
from onyx.db.credentials import update_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.db.models import User
|
||||
@@ -89,7 +89,7 @@ def delete_credential_by_id_admin(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
"""Same as the user endpoint, but can delete any credential (not just the user's own)"""
|
||||
delete_credential(db_session=db_session, credential_id=credential_id)
|
||||
delete_credential(db_session=db_session, credential_id=credential_id, user=None)
|
||||
return StatusResponse(
|
||||
success=True, message="Credential deleted successfully", data=credential_id
|
||||
)
|
||||
@@ -100,11 +100,13 @@ def swap_credentials_for_connector(
|
||||
credential_swap_req: CredentialSwapRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse:
|
||||
validate_ccpair_for_user(
|
||||
credential_swap_req.connector_id,
|
||||
credential_swap_req.new_credential_id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
connector_credential_pair = swap_credentials_connector(
|
||||
@@ -243,7 +245,7 @@ def delete_credential_by_id(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
delete_credential_for_user(
|
||||
delete_credential(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
@@ -260,7 +262,7 @@ def force_delete_credential_by_id(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
delete_credential_for_user(credential_id, user, db_session, True)
|
||||
delete_credential(credential_id, user, db_session, True)
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Credential deleted successfully", data=credential_id
|
||||
|
||||
@@ -49,7 +49,6 @@ def get_folders(
|
||||
name=chat_session.description,
|
||||
persona_id=chat_session.persona_id,
|
||||
time_created=chat_session.time_created.isoformat(),
|
||||
time_updated=chat_session.time_updated.isoformat(),
|
||||
shared_status=chat_session.shared_status,
|
||||
folder_id=folder.id,
|
||||
)
|
||||
|
||||
@@ -147,11 +147,9 @@ def list_threads(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
|
||||
@@ -119,7 +119,6 @@ def get_user_chat_sessions(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
|
||||
@@ -181,7 +181,6 @@ class ChatSessionDetails(BaseModel):
|
||||
name: str | None
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
time_updated: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
folder_id: int | None = None
|
||||
current_alternate_model: str | None = None
|
||||
@@ -242,7 +241,6 @@ class ChatMessageDetail(BaseModel):
|
||||
files: list[FileDescriptor]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
refined_answer_improvement: bool | None = None
|
||||
is_agentic: bool | None = None
|
||||
error: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
|
||||
@@ -159,7 +159,6 @@ def get_user_search_sessions(
|
||||
name=sessions_with_documents_dict[search.id],
|
||||
persona_id=search.persona_id,
|
||||
time_created=search.time_created.isoformat(),
|
||||
time_updated=search.time_updated.isoformat(),
|
||||
shared_status=search.shared_status,
|
||||
folder_id=search.folder_id,
|
||||
current_alternate_model=search.current_alternate_model,
|
||||
|
||||
@@ -4,9 +4,7 @@ from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.models import Notification as NotificationDBModel
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
@@ -51,10 +49,9 @@ class Settings(BaseModel):
|
||||
|
||||
temperature_override_enabled: bool | None = False
|
||||
auto_scroll: bool | None = False
|
||||
query_history_type: QueryHistoryType | None = None
|
||||
|
||||
|
||||
class UserSettings(Settings):
|
||||
notifications: list[Notification]
|
||||
needs_reindexing: bool
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA
|
||||
tenant_id: str | None = None
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -46,7 +45,6 @@ def load_settings() -> Settings:
|
||||
anonymous_user_enabled = False
|
||||
|
||||
settings.anonymous_user_enabled = anonymous_user_enabled
|
||||
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def setup_onyx(
|
||||
db_session: Session, tenant_id: str, cohere_enabled: bool = False
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Setup Onyx for a particular tenant. In the Single Tenant case, it will set it up for the default schema
|
||||
|
||||
@@ -260,7 +260,7 @@ def get_documents_for_tenant_connector(
|
||||
def search_for_document(
|
||||
index_name: str,
|
||||
document_id: str | None = None,
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
tenant_id: str | None = None,
|
||||
max_hits: int | None = 10,
|
||||
) -> List[Dict[str, Any]]:
|
||||
yql_query = f"select * from sources {index_name}"
|
||||
@@ -507,9 +507,9 @@ def get_number_of_chunks_we_think_exist(
|
||||
|
||||
class VespaDebugging:
|
||||
# Class for managing Vespa debugging actions.
|
||||
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
|
||||
def __init__(self, tenant_id: str | None = None):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
self.tenant_id = tenant_id
|
||||
self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id
|
||||
self.index_name = get_index_name(self.tenant_id)
|
||||
|
||||
def sample_document_counts(self) -> None:
|
||||
@@ -603,7 +603,7 @@ class VespaDebugging:
|
||||
delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)
|
||||
|
||||
def search_for_document(
|
||||
self, document_id: str | None = None, tenant_id: str = POSTGRES_DEFAULT_SCHEMA
|
||||
self, document_id: str | None = None, tenant_id: str | None = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
return search_for_document(self.index_name, document_id, tenant_id)
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Modify sys.path
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -75,7 +74,7 @@ def _unsafe_deletion(
|
||||
for document in documents:
|
||||
document_index.delete_single(
|
||||
doc_id=document.id,
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
tenant_id=None,
|
||||
chunk_count=document.chunk_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# makes it so `PYTHONPATH=.` is not required when running this script
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -97,9 +96,7 @@ def main() -> None:
|
||||
try:
|
||||
print(f"Deleting document {doc_id} in Vespa")
|
||||
chunks_deleted = vespa_index.delete_single(
|
||||
doc_id,
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
chunk_count=document.chunk_count,
|
||||
doc_id, tenant_id=None, chunk_count=document.chunk_count
|
||||
)
|
||||
if chunks_deleted > 0:
|
||||
print(
|
||||
|
||||
@@ -18,7 +18,5 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
|
||||
def get_current_tenant_id() -> str:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id is None:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
raise RuntimeError("Tenant ID is not set. This should never happen.")
|
||||
return tenant_id
|
||||
|
||||
@@ -87,7 +87,7 @@ def test_confluence_connector_basic(
|
||||
assert len(txt_doc.sections) == 1
|
||||
assert txt_doc.sections[0].text == "small"
|
||||
assert txt_doc.primary_owners
|
||||
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
|
||||
assert txt_doc.primary_owners[0].email == "chris@danswer.ai"
|
||||
assert (
|
||||
txt_doc.sections[0].link
|
||||
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
|
||||
|
||||
89
web/package-lock.json
generated
89
web/package-lock.json
generated
@@ -70,8 +70,6 @@
|
||||
"recharts": "^2.13.1",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-prism-plus": "^2.0.0",
|
||||
"rehype-sanitize": "^6.0.0",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0",
|
||||
"semver": "^7.5.4",
|
||||
@@ -11743,54 +11741,6 @@
|
||||
"resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz",
|
||||
"integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA=="
|
||||
},
|
||||
"node_modules/hast-util-sanitize": {
|
||||
"version": "5.0.2",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-sanitize/-/hast-util-sanitize-5.0.2.tgz",
|
||||
"integrity": "sha512-3yTWghByc50aGS7JlGhk61SPenfE/p1oaFeNwkOOyrscaOkMGrcW9+Cy/QAIOBpZxP1yqDIzFMR0+Np0i0+usg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"@ungap/structured-clone": "^1.0.0",
|
||||
"unist-util-position": "^5.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-html": {
|
||||
"version": "9.0.5",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-to-html/-/hast-util-to-html-9.0.5.tgz",
|
||||
"integrity": "sha512-OguPdidb+fbHQSU4Q4ZiLKnzWo8Wwsf5bZfbvu7//a9oTYoqD/fWpe96NuHkoS9h0ccGOTe0C4NGXdtS0iObOw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"@types/unist": "^3.0.0",
|
||||
"ccount": "^2.0.0",
|
||||
"comma-separated-tokens": "^2.0.0",
|
||||
"hast-util-whitespace": "^3.0.0",
|
||||
"html-void-elements": "^3.0.0",
|
||||
"mdast-util-to-hast": "^13.0.0",
|
||||
"property-information": "^7.0.0",
|
||||
"space-separated-tokens": "^2.0.0",
|
||||
"stringify-entities": "^4.0.0",
|
||||
"zwitch": "^2.0.4"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-html/node_modules/property-information": {
|
||||
"version": "7.0.0",
|
||||
"resolved": "https://registry.npmjs.org/property-information/-/property-information-7.0.0.tgz",
|
||||
"integrity": "sha512-7D/qOz/+Y4X/rzSB6jKxKUsQnphO046ei8qxG59mtM3RG3DHgTK81HrxrmoDVINJb8NKT5ZsRbwHvQ6B68Iyhg==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-jsx-runtime": {
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.0.tgz",
|
||||
@@ -11969,16 +11919,6 @@
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/html-void-elements": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz",
|
||||
"integrity": "sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/html-webpack-plugin": {
|
||||
"version": "5.6.3",
|
||||
"resolved": "https://registry.npmjs.org/html-webpack-plugin/-/html-webpack-plugin-5.6.3.tgz",
|
||||
@@ -19185,35 +19125,6 @@
|
||||
"unist-util-visit": "^5.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/rehype-sanitize": {
|
||||
"version": "6.0.0",
|
||||
"resolved": "https://registry.npmjs.org/rehype-sanitize/-/rehype-sanitize-6.0.0.tgz",
|
||||
"integrity": "sha512-CsnhKNsyI8Tub6L4sm5ZFsme4puGfc6pYylvXo1AeqaGbjOYyzNv3qZPwvs0oMJ39eryyeOdmxwUIo94IpEhqg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"hast-util-sanitize": "^5.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/rehype-stringify": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/rehype-stringify/-/rehype-stringify-10.0.1.tgz",
|
||||
"integrity": "sha512-k9ecfXHmIPuFVI61B9DeLPN0qFHfawM6RsuX48hoqlaKSF61RskNjSm1lI8PhBEM0MRdLxVVm4WmTqJQccH9mA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"hast-util-to-html": "^9.0.0",
|
||||
"unified": "^11.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/relateurl": {
|
||||
"version": "0.2.7",
|
||||
"resolved": "https://registry.npmjs.org/relateurl/-/relateurl-0.2.7.tgz",
|
||||
|
||||
@@ -73,8 +73,6 @@
|
||||
"recharts": "^2.13.1",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-prism-plus": "^2.0.0",
|
||||
"rehype-sanitize": "^6.0.0",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0",
|
||||
"semver": "^7.5.4",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Button } from "@/components/Button";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import * as Yup from "yup";
|
||||
import { useRouter } from "next/navigation";
|
||||
@@ -17,18 +17,13 @@ import {
|
||||
GoogleDriveCredentialJson,
|
||||
GoogleDriveServiceAccountCredentialJson,
|
||||
} from "@/lib/connectors/credentials";
|
||||
import { refreshAllGoogleData } from "@/lib/googleConnector";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
|
||||
type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account";
|
||||
|
||||
export const DriveJsonUpload = ({
|
||||
setPopup,
|
||||
onSuccess,
|
||||
}: {
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
onSuccess?: () => void;
|
||||
}) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [credentialJsonStr, setCredentialJsonStr] = useState<
|
||||
@@ -67,6 +62,7 @@ export const DriveJsonUpload = ({
|
||||
<Button
|
||||
disabled={!credentialJsonStr}
|
||||
onClick={async () => {
|
||||
// check if the JSON is a app credential or a service account credential
|
||||
let credentialFileType: GoogleDriveCredentialJsonTypes;
|
||||
try {
|
||||
const appCredentialJson = JSON.parse(credentialJsonStr!);
|
||||
@@ -103,10 +99,6 @@ export const DriveJsonUpload = ({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
@@ -114,6 +106,7 @@ export const DriveJsonUpload = ({
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
mutate("/api/manage/admin/connector/google-drive/app-credential");
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
@@ -129,22 +122,19 @@ export const DriveJsonUpload = ({
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
}
|
||||
}}
|
||||
>
|
||||
@@ -159,7 +149,6 @@ interface DriveJsonUploadSectionProps {
|
||||
appCredentialData?: { client_id: string };
|
||||
serviceAccountCredentialData?: { service_account_email: string };
|
||||
isAdmin: boolean;
|
||||
onSuccess?: () => void;
|
||||
}
|
||||
|
||||
export const DriveJsonUploadSection = ({
|
||||
@@ -167,36 +156,17 @@ export const DriveJsonUploadSection = ({
|
||||
appCredentialData,
|
||||
serviceAccountCredentialData,
|
||||
isAdmin,
|
||||
onSuccess,
|
||||
}: DriveJsonUploadSectionProps) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const router = useRouter();
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
serviceAccountCredentialData
|
||||
);
|
||||
const [localAppCredentialData, setLocalAppCredentialData] =
|
||||
useState(appCredentialData);
|
||||
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountCredentialData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
}, [serviceAccountCredentialData, appCredentialData]);
|
||||
|
||||
const handleSuccess = () => {
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
} else {
|
||||
refreshAllGoogleData(ValidSources.GoogleDrive);
|
||||
}
|
||||
};
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
if (serviceAccountCredentialData?.service_account_email) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing service account key with the following <b>Email:</b>
|
||||
<p className="italic mt-1">
|
||||
{localServiceAccountData.service_account_email}
|
||||
{serviceAccountCredentialData.service_account_email}
|
||||
</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
@@ -218,15 +188,11 @@ export const DriveJsonUploadSection = ({
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key"
|
||||
);
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
setLocalServiceAccountData(undefined);
|
||||
handleSuccess();
|
||||
router.refresh();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
@@ -250,12 +216,12 @@ export const DriveJsonUploadSection = ({
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
if (appCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing app credentials with the following <b>Client ID:</b>
|
||||
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
|
||||
<p className="italic mt-1">{appCredentialData.client_id}</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
@@ -276,15 +242,10 @@ export const DriveJsonUploadSection = ({
|
||||
mutate(
|
||||
"/api/manage/admin/connector/google-drive/app-credential"
|
||||
);
|
||||
mutate(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
|
||||
);
|
||||
setPopup({
|
||||
message: "Successfully deleted app credentials",
|
||||
type: "success",
|
||||
});
|
||||
setLocalAppCredentialData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
@@ -336,7 +297,7 @@ export const DriveJsonUploadSection = ({
|
||||
Download the credentials JSON if choosing option (1) or the Service
|
||||
Account key JSON if chooosing option (2), and upload it here.
|
||||
</p>
|
||||
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
<DriveJsonUpload setPopup={setPopup} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -387,41 +348,13 @@ export const DriveAuthSection = ({
|
||||
appCredentialData,
|
||||
setPopup,
|
||||
refreshCredentials,
|
||||
connectorAssociated,
|
||||
connectorAssociated, // don't allow revoke if a connector / credential pair is active with the uploaded credential
|
||||
user,
|
||||
}: DriveCredentialSectionProps) => {
|
||||
const router = useRouter();
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
serviceAccountKeyData
|
||||
);
|
||||
const [localAppCredentialData, setLocalAppCredentialData] =
|
||||
useState(appCredentialData);
|
||||
const [
|
||||
localGoogleDrivePublicCredential,
|
||||
setLocalGoogleDrivePublicCredential,
|
||||
] = useState(googleDrivePublicUploadedCredential);
|
||||
const [
|
||||
localGoogleDriveServiceAccountCredential,
|
||||
setLocalGoogleDriveServiceAccountCredential,
|
||||
] = useState(googleDriveServiceAccountCredential);
|
||||
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountKeyData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
setLocalGoogleDrivePublicCredential(googleDrivePublicUploadedCredential);
|
||||
setLocalGoogleDriveServiceAccountCredential(
|
||||
googleDriveServiceAccountCredential
|
||||
);
|
||||
}, [
|
||||
serviceAccountKeyData,
|
||||
appCredentialData,
|
||||
googleDrivePublicUploadedCredential,
|
||||
googleDriveServiceAccountCredential,
|
||||
]);
|
||||
|
||||
const existingCredential =
|
||||
localGoogleDrivePublicCredential ||
|
||||
localGoogleDriveServiceAccountCredential;
|
||||
googleDrivePublicUploadedCredential || googleDriveServiceAccountCredential;
|
||||
if (existingCredential) {
|
||||
return (
|
||||
<>
|
||||
@@ -444,7 +377,7 @@ export const DriveAuthSection = ({
|
||||
);
|
||||
}
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
if (serviceAccountKeyData?.service_account_email) {
|
||||
return (
|
||||
<div>
|
||||
<Formik
|
||||
@@ -505,7 +438,7 @@ export const DriveAuthSection = ({
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
if (appCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="text-sm mb-4">
|
||||
<p className="mb-2">
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { FetchError } from "@/lib/fetcher";
|
||||
import React, { useEffect, useState } from "react";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { FetchError, errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
@@ -14,17 +15,22 @@ import {
|
||||
GoogleDriveCredentialJson,
|
||||
GoogleDriveServiceAccountCredentialJson,
|
||||
} from "@/lib/connectors/credentials";
|
||||
import { ConnectorSnapshot } from "@/lib/connectors/connectors";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import {
|
||||
useGoogleAppCredential,
|
||||
useGoogleServiceAccountKey,
|
||||
useGoogleCredentials,
|
||||
useConnectorsByCredentialId,
|
||||
checkCredentialsFetched,
|
||||
filterUploadedCredentials,
|
||||
checkConnectorsExist,
|
||||
refreshAllGoogleData,
|
||||
} from "@/lib/googleConnector";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
|
||||
const useConnectorsByCredentialId = (credential_id: number | null) => {
|
||||
let url: string | null = null;
|
||||
if (credential_id !== null) {
|
||||
url = `/api/manage/admin/connector?credential=${credential_id}`;
|
||||
}
|
||||
const swrResponse = useSWR<ConnectorSnapshot[]>(url, errorHandlingFetcher);
|
||||
|
||||
return {
|
||||
...swrResponse,
|
||||
refreshConnectorsByCredentialId: () => mutate(url),
|
||||
};
|
||||
};
|
||||
|
||||
const GDriveMain = ({
|
||||
setPopup,
|
||||
@@ -33,20 +39,27 @@ const GDriveMain = ({
|
||||
}) => {
|
||||
const { isAdmin, user } = useUser();
|
||||
|
||||
// Get app credential and service account key
|
||||
// tries getting the uploaded credential json
|
||||
const {
|
||||
data: appCredentialData,
|
||||
isLoading: isAppCredentialLoading,
|
||||
error: isAppCredentialError,
|
||||
} = useGoogleAppCredential("google_drive");
|
||||
} = useSWR<{ client_id: string }, FetchError>(
|
||||
"/api/manage/admin/connector/google-drive/app-credential",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
// tries getting the uploaded service account key
|
||||
const {
|
||||
data: serviceAccountKeyData,
|
||||
isLoading: isServiceAccountKeyLoading,
|
||||
error: isServiceAccountKeyError,
|
||||
} = useGoogleServiceAccountKey("google_drive");
|
||||
} = useSWR<{ service_account_email: string }, FetchError>(
|
||||
"/api/manage/admin/connector/google-drive/service-account-key",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
// Get all public credentials
|
||||
// gets all public credentials
|
||||
const {
|
||||
data: credentialsData,
|
||||
isLoading: isCredentialsLoading,
|
||||
@@ -54,19 +67,33 @@ const GDriveMain = ({
|
||||
refreshCredentials,
|
||||
} = usePublicCredentials();
|
||||
|
||||
// Get Google Drive-specific credentials
|
||||
// gets all credentials for source type google drive
|
||||
const {
|
||||
data: googleDriveCredentials,
|
||||
isLoading: isGoogleDriveCredentialsLoading,
|
||||
error: googleDriveCredentialsError,
|
||||
} = useGoogleCredentials(ValidSources.GoogleDrive);
|
||||
|
||||
// Filter uploaded credentials and get credential ID
|
||||
const { credential_id, uploadedCredentials } = filterUploadedCredentials(
|
||||
googleDriveCredentials
|
||||
} = useSWR<Credential<any>[]>(
|
||||
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive),
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 }
|
||||
);
|
||||
|
||||
// Get connectors for the credential ID
|
||||
// filters down to just credentials that were created via upload (there should be only one)
|
||||
let credential_id = null;
|
||||
if (googleDriveCredentials) {
|
||||
const googleDriveUploadedCredentials: Credential<GoogleDriveCredentialJson>[] =
|
||||
googleDriveCredentials.filter(
|
||||
(googleDriveCredential) =>
|
||||
googleDriveCredential.credential_json.authentication_method !==
|
||||
"oauth_interactive"
|
||||
);
|
||||
|
||||
if (googleDriveUploadedCredentials.length > 0) {
|
||||
credential_id = googleDriveUploadedCredentials[0].id;
|
||||
}
|
||||
}
|
||||
|
||||
// retrieves all connectors for that credential id
|
||||
const {
|
||||
data: googleDriveConnectors,
|
||||
isLoading: isGoogleDriveConnectorsLoading,
|
||||
@@ -74,25 +101,13 @@ const GDriveMain = ({
|
||||
refreshConnectorsByCredentialId,
|
||||
} = useConnectorsByCredentialId(credential_id);
|
||||
|
||||
// Check if credentials were successfully fetched
|
||||
const {
|
||||
appCredentialSuccessfullyFetched,
|
||||
serviceAccountKeySuccessfullyFetched,
|
||||
} = checkCredentialsFetched(
|
||||
appCredentialData,
|
||||
isAppCredentialError,
|
||||
serviceAccountKeyData,
|
||||
isServiceAccountKeyError
|
||||
);
|
||||
const appCredentialSuccessfullyFetched =
|
||||
appCredentialData ||
|
||||
(isAppCredentialError && isAppCredentialError.status === 404);
|
||||
const serviceAccountKeySuccessfullyFetched =
|
||||
serviceAccountKeyData ||
|
||||
(isServiceAccountKeyError && isServiceAccountKeyError.status === 404);
|
||||
|
||||
// Handle refresh of all data
|
||||
const handleRefresh = () => {
|
||||
refreshCredentials();
|
||||
refreshConnectorsByCredentialId();
|
||||
refreshAllGoogleData(ValidSources.GoogleDrive);
|
||||
};
|
||||
|
||||
// Loading state
|
||||
if (
|
||||
(!appCredentialSuccessfullyFetched && isAppCredentialLoading) ||
|
||||
(!serviceAccountKeySuccessfullyFetched && isServiceAccountKeyLoading) ||
|
||||
@@ -107,7 +122,6 @@ const GDriveMain = ({
|
||||
);
|
||||
}
|
||||
|
||||
// Error states
|
||||
if (credentialsError || !credentialsData) {
|
||||
return <ErrorCallout errorTitle="Failed to load credentials." />;
|
||||
}
|
||||
@@ -127,16 +141,7 @@ const GDriveMain = ({
|
||||
);
|
||||
}
|
||||
|
||||
if (googleDriveConnectorsError) {
|
||||
return (
|
||||
<ErrorCallout errorTitle="Failed to load Google Drive associated connectors." />
|
||||
);
|
||||
}
|
||||
|
||||
// Check if connectors exist
|
||||
const connectorAssociated = checkConnectorsExist(googleDriveConnectors);
|
||||
|
||||
// Get the uploaded OAuth credential
|
||||
// get the actual uploaded oauth or service account credentials
|
||||
const googleDrivePublicUploadedCredential:
|
||||
| Credential<GoogleDriveCredentialJson>
|
||||
| undefined = credentialsData.find(
|
||||
@@ -147,7 +152,6 @@ const GDriveMain = ({
|
||||
credential.credential_json.authentication_method !== "oauth_interactive"
|
||||
);
|
||||
|
||||
// Get the service account credential
|
||||
const googleDriveServiceAccountCredential:
|
||||
| Credential<GoogleDriveServiceAccountCredentialJson>
|
||||
| undefined = credentialsData.find(
|
||||
@@ -156,6 +160,19 @@ const GDriveMain = ({
|
||||
credential.source === "google_drive"
|
||||
);
|
||||
|
||||
if (googleDriveConnectorsError) {
|
||||
return (
|
||||
<ErrorCallout errorTitle="Failed to load Google Drive associated connectors." />
|
||||
);
|
||||
}
|
||||
|
||||
let connectorAssociated = false;
|
||||
if (googleDriveConnectors) {
|
||||
if (googleDriveConnectors.length > 0) {
|
||||
connectorAssociated = true;
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Title className="mb-2 mt-6">Step 1: Provide your Credentials</Title>
|
||||
@@ -164,30 +181,27 @@ const GDriveMain = ({
|
||||
appCredentialData={appCredentialData}
|
||||
serviceAccountCredentialData={serviceAccountKeyData}
|
||||
isAdmin={isAdmin}
|
||||
onSuccess={handleRefresh}
|
||||
/>
|
||||
|
||||
{isAdmin &&
|
||||
(appCredentialData?.client_id ||
|
||||
serviceAccountKeyData?.service_account_email) && (
|
||||
<>
|
||||
<Title className="mb-2 mt-6">Step 2: Authenticate with Onyx</Title>
|
||||
<DriveAuthSection
|
||||
setPopup={setPopup}
|
||||
refreshCredentials={handleRefresh}
|
||||
googleDrivePublicUploadedCredential={
|
||||
googleDrivePublicUploadedCredential
|
||||
}
|
||||
googleDriveServiceAccountCredential={
|
||||
googleDriveServiceAccountCredential
|
||||
}
|
||||
appCredentialData={appCredentialData}
|
||||
serviceAccountKeyData={serviceAccountKeyData}
|
||||
connectorAssociated={connectorAssociated}
|
||||
user={user}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{isAdmin && (
|
||||
<>
|
||||
<Title className="mb-2 mt-6">Step 2: Authenticate with Onyx</Title>
|
||||
<DriveAuthSection
|
||||
setPopup={setPopup}
|
||||
refreshCredentials={refreshCredentials}
|
||||
googleDrivePublicUploadedCredential={
|
||||
googleDrivePublicUploadedCredential
|
||||
}
|
||||
googleDriveServiceAccountCredential={
|
||||
googleDriveServiceAccountCredential
|
||||
}
|
||||
appCredentialData={appCredentialData}
|
||||
serviceAccountKeyData={serviceAccountKeyData}
|
||||
connectorAssociated={connectorAssociated}
|
||||
user={user}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Button } from "@/components/Button";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import * as Yup from "yup";
|
||||
import { useRouter } from "next/navigation";
|
||||
@@ -17,18 +17,13 @@ import {
|
||||
GmailCredentialJson,
|
||||
GmailServiceAccountCredentialJson,
|
||||
} from "@/lib/connectors/credentials";
|
||||
import { refreshAllGoogleData } from "@/lib/googleConnector";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
|
||||
type GmailCredentialJsonTypes = "authorized_user" | "service_account";
|
||||
|
||||
const DriveJsonUpload = ({
|
||||
setPopup,
|
||||
onSuccess,
|
||||
}: {
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
onSuccess?: () => void;
|
||||
}) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [credentialJsonStr, setCredentialJsonStr] = useState<
|
||||
@@ -77,7 +72,7 @@ const DriveJsonUpload = ({
|
||||
credentialFileType = "service_account";
|
||||
} else {
|
||||
throw new Error(
|
||||
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
|
||||
"Unknown credential type, expected 'OAuth Web application'"
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
@@ -104,10 +99,6 @@ const DriveJsonUpload = ({
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
@@ -115,6 +106,7 @@ const DriveJsonUpload = ({
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
}
|
||||
|
||||
if (credentialFileType === "service_account") {
|
||||
@@ -130,20 +122,17 @@ const DriveJsonUpload = ({
|
||||
);
|
||||
if (response.ok) {
|
||||
setPopup({
|
||||
message: "Successfully uploaded service account key",
|
||||
message: "Successfully uploaded app credentials",
|
||||
type: "success",
|
||||
});
|
||||
mutate("/api/manage/admin/connector/gmail/service-account-key");
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
}
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to upload service account key - ${errorMsg}`,
|
||||
message: `Failed to upload app credentials - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
mutate("/api/manage/admin/connector/gmail/service-account-key");
|
||||
}
|
||||
}}
|
||||
>
|
||||
@@ -158,7 +147,6 @@ interface DriveJsonUploadSectionProps {
|
||||
appCredentialData?: { client_id: string };
|
||||
serviceAccountCredentialData?: { service_account_email: string };
|
||||
isAdmin: boolean;
|
||||
onSuccess?: () => void;
|
||||
}
|
||||
|
||||
export const GmailJsonUploadSection = ({
|
||||
@@ -166,37 +154,16 @@ export const GmailJsonUploadSection = ({
|
||||
appCredentialData,
|
||||
serviceAccountCredentialData,
|
||||
isAdmin,
|
||||
onSuccess,
|
||||
}: DriveJsonUploadSectionProps) => {
|
||||
const { mutate } = useSWRConfig();
|
||||
const router = useRouter();
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
serviceAccountCredentialData
|
||||
);
|
||||
const [localAppCredentialData, setLocalAppCredentialData] =
|
||||
useState(appCredentialData);
|
||||
|
||||
// Update local state when props change
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountCredentialData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
}, [serviceAccountCredentialData, appCredentialData]);
|
||||
|
||||
const handleSuccess = () => {
|
||||
if (onSuccess) {
|
||||
onSuccess();
|
||||
} else {
|
||||
refreshAllGoogleData(ValidSources.Gmail);
|
||||
}
|
||||
};
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
if (serviceAccountCredentialData?.service_account_email) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing service account key with the following <b>Email:</b>
|
||||
<p className="italic mt-1">
|
||||
{localServiceAccountData.service_account_email}
|
||||
{serviceAccountCredentialData.service_account_email}
|
||||
</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
@@ -218,15 +185,10 @@ export const GmailJsonUploadSection = ({
|
||||
mutate(
|
||||
"/api/manage/admin/connector/gmail/service-account-key"
|
||||
);
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
setLocalServiceAccountData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
@@ -250,56 +212,43 @@ export const GmailJsonUploadSection = ({
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
if (appCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="mt-2 text-sm">
|
||||
<div>
|
||||
Found existing app credentials with the following <b>Client ID:</b>
|
||||
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
|
||||
<p className="italic mt-1">{appCredentialData.client_id}</p>
|
||||
</div>
|
||||
{isAdmin ? (
|
||||
<>
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
// Also mutate the credential endpoints to ensure Step 2 is reset
|
||||
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
|
||||
setPopup({
|
||||
message: "Successfully deleted app credentials",
|
||||
type: "success",
|
||||
});
|
||||
// Immediately update local state
|
||||
setLocalAppCredentialData(undefined);
|
||||
handleSuccess();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-4 mb-1">
|
||||
To change these credentials, please contact an administrator.
|
||||
</div>
|
||||
)}
|
||||
<div className="mt-4 mb-1">
|
||||
If you want to update these credentials, delete the existing
|
||||
credentials through the button below, and then upload a new
|
||||
credentials JSON.
|
||||
</div>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
mutate("/api/manage/admin/connector/gmail/app-credential");
|
||||
setPopup({
|
||||
message: "Successfully deleted service account key",
|
||||
type: "success",
|
||||
});
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
message: `Failed to delete app credential - ${errorMsg}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -327,14 +276,14 @@ export const GmailJsonUploadSection = ({
|
||||
>
|
||||
here
|
||||
</a>{" "}
|
||||
to either (1) setup a Google OAuth App in your company workspace or (2)
|
||||
to either (1) setup a google OAuth App in your company workspace or (2)
|
||||
create a Service Account.
|
||||
<br />
|
||||
<br />
|
||||
Download the credentials JSON if choosing option (1) or the Service
|
||||
Account key JSON if choosing option (2), and upload it here.
|
||||
Account key JSON if chooosing option (2), and upload it here.
|
||||
</p>
|
||||
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
|
||||
<DriveJsonUpload setPopup={setPopup} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -350,34 +299,6 @@ interface DriveCredentialSectionProps {
|
||||
user: User | null;
|
||||
}
|
||||
|
||||
async function handleRevokeAccess(
|
||||
connectorExists: boolean,
|
||||
setPopup: (popupSpec: PopupSpec | null) => void,
|
||||
existingCredential:
|
||||
| Credential<GmailCredentialJson>
|
||||
| Credential<GmailServiceAccountCredentialJson>,
|
||||
refreshCredentials: () => void
|
||||
) {
|
||||
if (connectorExists) {
|
||||
const message =
|
||||
"Cannot revoke the Gmail credential while any connector is still associated with the credential. " +
|
||||
"Please delete all associated connectors, then try again.";
|
||||
setPopup({
|
||||
message: message,
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
await adminDeleteCredential(existingCredential.id);
|
||||
setPopup({
|
||||
message: "Successfully revoked the Gmail credential!",
|
||||
type: "success",
|
||||
});
|
||||
|
||||
refreshCredentials();
|
||||
}
|
||||
|
||||
export const GmailAuthSection = ({
|
||||
gmailPublicCredential,
|
||||
gmailServiceAccountCredential,
|
||||
@@ -389,49 +310,31 @@ export const GmailAuthSection = ({
|
||||
user,
|
||||
}: DriveCredentialSectionProps) => {
|
||||
const router = useRouter();
|
||||
const [isAuthenticating, setIsAuthenticating] = useState(false);
|
||||
const [localServiceAccountData, setLocalServiceAccountData] = useState(
|
||||
serviceAccountKeyData
|
||||
);
|
||||
const [localAppCredentialData, setLocalAppCredentialData] =
|
||||
useState(appCredentialData);
|
||||
const [localGmailPublicCredential, setLocalGmailPublicCredential] = useState(
|
||||
gmailPublicCredential
|
||||
);
|
||||
const [
|
||||
localGmailServiceAccountCredential,
|
||||
setLocalGmailServiceAccountCredential,
|
||||
] = useState(gmailServiceAccountCredential);
|
||||
|
||||
// Update local state when props change
|
||||
useEffect(() => {
|
||||
setLocalServiceAccountData(serviceAccountKeyData);
|
||||
setLocalAppCredentialData(appCredentialData);
|
||||
setLocalGmailPublicCredential(gmailPublicCredential);
|
||||
setLocalGmailServiceAccountCredential(gmailServiceAccountCredential);
|
||||
}, [
|
||||
serviceAccountKeyData,
|
||||
appCredentialData,
|
||||
gmailPublicCredential,
|
||||
gmailServiceAccountCredential,
|
||||
]);
|
||||
|
||||
const existingCredential =
|
||||
localGmailPublicCredential || localGmailServiceAccountCredential;
|
||||
gmailPublicCredential || gmailServiceAccountCredential;
|
||||
if (existingCredential) {
|
||||
return (
|
||||
<>
|
||||
<p className="mb-2 text-sm">
|
||||
<i>Uploaded and authenticated credential already exists!</i>
|
||||
<i>Existing credential already set up!</i>
|
||||
</p>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
handleRevokeAccess(
|
||||
connectorExists,
|
||||
setPopup,
|
||||
existingCredential,
|
||||
refreshCredentials
|
||||
);
|
||||
if (connectorExists) {
|
||||
setPopup({
|
||||
message:
|
||||
"Cannot revoke access to Gmail while any connector is still set up. Please delete all connectors, then try again.",
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
await adminDeleteCredential(existingCredential.id);
|
||||
setPopup({
|
||||
message: "Successfully revoked access to Gmail!",
|
||||
type: "success",
|
||||
});
|
||||
refreshCredentials();
|
||||
}}
|
||||
>
|
||||
Revoke Access
|
||||
@@ -440,21 +343,20 @@ export const GmailAuthSection = ({
|
||||
);
|
||||
}
|
||||
|
||||
if (localServiceAccountData?.service_account_email) {
|
||||
if (serviceAccountKeyData?.service_account_email) {
|
||||
return (
|
||||
<div>
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_primary_admin: user?.email || "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_primary_admin: Yup.string()
|
||||
.email("Must be a valid email")
|
||||
.required("Required"),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
try {
|
||||
<CardSection>
|
||||
<Formik
|
||||
initialValues={{
|
||||
google_primary_admin: user?.email || "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
google_primary_admin: Yup.string().required(),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
|
||||
const response = await fetch(
|
||||
"/api/manage/admin/connector/gmail/service-account-credential",
|
||||
{
|
||||
@@ -473,7 +375,6 @@ export const GmailAuthSection = ({
|
||||
message: "Successfully created service account credential",
|
||||
type: "success",
|
||||
});
|
||||
refreshCredentials();
|
||||
} else {
|
||||
const errorMsg = await response.text();
|
||||
setPopup({
|
||||
@@ -481,73 +382,65 @@ export const GmailAuthSection = ({
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: `Failed to create service account credential - ${error}`,
|
||||
type: "error",
|
||||
});
|
||||
} finally {
|
||||
formikHelpers.setSubmitting(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_primary_admin"
|
||||
label="Primary Admin Email:"
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
Create Credential
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
refreshCredentials();
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="google_primary_admin"
|
||||
label="Primary Admin Email:"
|
||||
subtext="You must provide an admin/owner account to retrieve all org emails."
|
||||
/>
|
||||
<div className="flex">
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className={
|
||||
"bg-slate-500 hover:bg-slate-700 text-white " +
|
||||
"font-bold py-2 px-4 rounded focus:outline-none " +
|
||||
"focus:shadow-outline w-full max-w-sm mx-auto"
|
||||
}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</CardSection>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (localAppCredentialData?.client_id) {
|
||||
if (appCredentialData?.client_id) {
|
||||
return (
|
||||
<div className="text-sm mb-4">
|
||||
<p className="mb-2">
|
||||
Next, you must provide credentials via OAuth. This gives us read
|
||||
access to the emails you have access to in your Gmail account.
|
||||
access to the docs you have access to in your gmail account.
|
||||
</p>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
const [authUrl, errorMsg] = await setupGmailOAuth({
|
||||
isAdmin: true,
|
||||
});
|
||||
if (authUrl) {
|
||||
// cookie used by callback to determine where to finally redirect to
|
||||
Cookies.set(GMAIL_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
});
|
||||
const [authUrl, errorMsg] = await setupGmailOAuth({
|
||||
isAdmin: true,
|
||||
});
|
||||
|
||||
if (authUrl) {
|
||||
router.push(authUrl);
|
||||
} else {
|
||||
setPopup({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
});
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
} catch (error) {
|
||||
setPopup({
|
||||
message: `Failed to authenticate with Gmail - ${error}`,
|
||||
type: "error",
|
||||
});
|
||||
setIsAuthenticating(false);
|
||||
router.push(authUrl);
|
||||
return;
|
||||
}
|
||||
|
||||
setPopup({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
});
|
||||
}}
|
||||
disabled={isAuthenticating}
|
||||
>
|
||||
{isAuthenticating ? "Authenticating..." : "Authenticate with Gmail"}
|
||||
Authenticate with Gmail
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
@@ -556,8 +449,8 @@ export const GmailAuthSection = ({
|
||||
// case where no keys have been uploaded in step 1
|
||||
return (
|
||||
<p className="text-sm">
|
||||
Please upload either a OAuth Client Credential JSON or a Gmail Service
|
||||
Account Key JSON in Step 1 before moving onto Step 2.
|
||||
Please upload an OAuth or Service Account Credential JSON in Step 1 before
|
||||
moving onto Step 2.
|
||||
</p>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { FetchError } from "@/lib/fetcher";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { CCPairBasicInfo, ValidSources } from "@/lib/types";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { CCPairBasicInfo } from "@/lib/types";
|
||||
import {
|
||||
Credential,
|
||||
GmailCredentialJson,
|
||||
@@ -15,33 +14,26 @@ import { GmailAuthSection, GmailJsonUploadSection } from "./Credential";
|
||||
import { usePublicCredentials, useBasicConnectorStatus } from "@/lib/hooks";
|
||||
import Title from "@/components/ui/title";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import {
|
||||
useGoogleAppCredential,
|
||||
useGoogleServiceAccountKey,
|
||||
useGoogleCredentials,
|
||||
useConnectorsByCredentialId,
|
||||
checkCredentialsFetched,
|
||||
filterUploadedCredentials,
|
||||
checkConnectorsExist,
|
||||
refreshAllGoogleData,
|
||||
} from "@/lib/googleConnector";
|
||||
|
||||
export const GmailMain = () => {
|
||||
const { isAdmin, user } = useUser();
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const {
|
||||
data: appCredentialData,
|
||||
isLoading: isAppCredentialLoading,
|
||||
error: isAppCredentialError,
|
||||
} = useGoogleAppCredential("gmail");
|
||||
|
||||
} = useSWR<{ client_id: string }>(
|
||||
"/api/manage/admin/connector/gmail/app-credential",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
const {
|
||||
data: serviceAccountKeyData,
|
||||
isLoading: isServiceAccountKeyLoading,
|
||||
error: isServiceAccountKeyError,
|
||||
} = useGoogleServiceAccountKey("gmail");
|
||||
|
||||
} = useSWR<{ service_account_email: string }>(
|
||||
"/api/manage/admin/connector/gmail/service-account-key",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
const {
|
||||
data: connectorIndexingStatuses,
|
||||
isLoading: isConnectorIndexingStatusesLoading,
|
||||
@@ -55,45 +47,20 @@ export const GmailMain = () => {
|
||||
refreshCredentials,
|
||||
} = usePublicCredentials();
|
||||
|
||||
const {
|
||||
data: gmailCredentials,
|
||||
isLoading: isGmailCredentialsLoading,
|
||||
error: gmailCredentialsError,
|
||||
} = useGoogleCredentials(ValidSources.Gmail);
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const { credential_id, uploadedCredentials } =
|
||||
filterUploadedCredentials(gmailCredentials);
|
||||
|
||||
const {
|
||||
data: gmailConnectors,
|
||||
isLoading: isGmailConnectorsLoading,
|
||||
error: gmailConnectorsError,
|
||||
refreshConnectorsByCredentialId,
|
||||
} = useConnectorsByCredentialId(credential_id);
|
||||
|
||||
const {
|
||||
appCredentialSuccessfullyFetched,
|
||||
serviceAccountKeySuccessfullyFetched,
|
||||
} = checkCredentialsFetched(
|
||||
appCredentialData,
|
||||
isAppCredentialError,
|
||||
serviceAccountKeyData,
|
||||
isServiceAccountKeyError
|
||||
);
|
||||
|
||||
const handleRefresh = () => {
|
||||
refreshCredentials();
|
||||
refreshConnectorsByCredentialId();
|
||||
refreshAllGoogleData(ValidSources.Gmail);
|
||||
};
|
||||
const appCredentialSuccessfullyFetched =
|
||||
appCredentialData ||
|
||||
(isAppCredentialError && isAppCredentialError.status === 404);
|
||||
const serviceAccountKeySuccessfullyFetched =
|
||||
serviceAccountKeyData ||
|
||||
(isServiceAccountKeyError && isServiceAccountKeyError.status === 404);
|
||||
|
||||
if (
|
||||
(!appCredentialSuccessfullyFetched && isAppCredentialLoading) ||
|
||||
(!serviceAccountKeySuccessfullyFetched && isServiceAccountKeyLoading) ||
|
||||
(!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) ||
|
||||
(!credentialsData && isCredentialsLoading) ||
|
||||
(!gmailCredentials && isGmailCredentialsLoading) ||
|
||||
(!gmailConnectors && isGmailConnectorsLoading)
|
||||
(!credentialsData && isCredentialsLoading)
|
||||
) {
|
||||
return (
|
||||
<div className="mx-auto">
|
||||
@@ -103,15 +70,19 @@ export const GmailMain = () => {
|
||||
}
|
||||
|
||||
if (credentialsError || !credentialsData) {
|
||||
return <ErrorCallout errorTitle="Failed to load credentials." />;
|
||||
}
|
||||
|
||||
if (gmailCredentialsError || !gmailCredentials) {
|
||||
return <ErrorCallout errorTitle="Failed to load Gmail credentials." />;
|
||||
return (
|
||||
<div className="mx-auto">
|
||||
<div className="text-red-500">Failed to load credentials.</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (connectorIndexingStatusesError || !connectorIndexingStatuses) {
|
||||
return <ErrorCallout errorTitle="Failed to load connectors." />;
|
||||
return (
|
||||
<div className="mx-auto">
|
||||
<div className="text-red-500">Failed to load connectors.</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -119,28 +90,21 @@ export const GmailMain = () => {
|
||||
!serviceAccountKeySuccessfullyFetched
|
||||
) {
|
||||
return (
|
||||
<ErrorCallout errorTitle="Error loading Gmail app credentials. Contact an administrator." />
|
||||
<div className="mx-auto">
|
||||
<div className="text-red-500">
|
||||
Error loading Gmail app credentials. Contact an administrator.
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (gmailConnectorsError) {
|
||||
return (
|
||||
<ErrorCallout errorTitle="Failed to load Gmail associated connectors." />
|
||||
const gmailPublicCredential: Credential<GmailCredentialJson> | undefined =
|
||||
credentialsData.find(
|
||||
(credential) =>
|
||||
(credential.credential_json?.google_service_account_key ||
|
||||
credential.credential_json?.google_tokens) &&
|
||||
credential.admin_public
|
||||
);
|
||||
}
|
||||
|
||||
const connectorExistsFromCredential = checkConnectorsExist(gmailConnectors);
|
||||
|
||||
const gmailPublicUploadedCredential:
|
||||
| Credential<GmailCredentialJson>
|
||||
| undefined = credentialsData.find(
|
||||
(credential) =>
|
||||
credential.credential_json?.google_tokens &&
|
||||
credential.admin_public &&
|
||||
credential.source === "gmail" &&
|
||||
credential.credential_json.authentication_method !== "oauth_interactive"
|
||||
);
|
||||
|
||||
const gmailServiceAccountCredential:
|
||||
| Credential<GmailServiceAccountCredentialJson>
|
||||
| undefined = credentialsData.find(
|
||||
@@ -154,13 +118,6 @@ export const GmailMain = () => {
|
||||
(connectorIndexingStatus) => connectorIndexingStatus.source === "gmail"
|
||||
);
|
||||
|
||||
const connectorExists =
|
||||
connectorExistsFromCredential || gmailConnectorIndexingStatuses.length > 0;
|
||||
|
||||
const hasUploadedCredentials =
|
||||
Boolean(appCredentialData?.client_id) ||
|
||||
Boolean(serviceAccountKeyData?.service_account_email);
|
||||
|
||||
return (
|
||||
<>
|
||||
{popup}
|
||||
@@ -172,22 +129,21 @@ export const GmailMain = () => {
|
||||
appCredentialData={appCredentialData}
|
||||
serviceAccountCredentialData={serviceAccountKeyData}
|
||||
isAdmin={isAdmin}
|
||||
onSuccess={handleRefresh}
|
||||
/>
|
||||
|
||||
{isAdmin && hasUploadedCredentials && (
|
||||
{isAdmin && (
|
||||
<>
|
||||
<Title className="mb-2 mt-6 ml-auto mr-auto">
|
||||
Step 2: Authenticate with Onyx
|
||||
</Title>
|
||||
<GmailAuthSection
|
||||
setPopup={setPopup}
|
||||
refreshCredentials={handleRefresh}
|
||||
gmailPublicCredential={gmailPublicUploadedCredential}
|
||||
refreshCredentials={refreshCredentials}
|
||||
gmailPublicCredential={gmailPublicCredential}
|
||||
gmailServiceAccountCredential={gmailServiceAccountCredential}
|
||||
appCredentialData={appCredentialData}
|
||||
serviceAccountKeyData={serviceAccountKeyData}
|
||||
connectorExists={connectorExists}
|
||||
connectorExists={gmailConnectorIndexingStatuses.length > 0}
|
||||
user={user}
|
||||
/>
|
||||
</>
|
||||
|
||||
@@ -4,12 +4,6 @@ export enum ApplicationStatus {
|
||||
ACTIVE = "active",
|
||||
}
|
||||
|
||||
export enum QueryHistoryType {
|
||||
DISABLED = "disabled",
|
||||
ANONYMIZED = "anonymized",
|
||||
NORMAL = "normal",
|
||||
}
|
||||
|
||||
export interface Settings {
|
||||
anonymous_user_enabled: boolean;
|
||||
maximum_chat_retention_days: number | null;
|
||||
@@ -20,7 +14,6 @@ export interface Settings {
|
||||
application_status: ApplicationStatus;
|
||||
auto_scroll: boolean;
|
||||
temperature_override_enabled: boolean;
|
||||
query_history_type: QueryHistoryType;
|
||||
}
|
||||
|
||||
export enum NotificationType {
|
||||
|
||||
@@ -36,14 +36,11 @@ export function EmailPasswordForm({
|
||||
{popup}
|
||||
<Formik
|
||||
initialValues={{
|
||||
email: defaultEmail ? defaultEmail.toLowerCase() : "",
|
||||
email: defaultEmail || "",
|
||||
password: "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
email: Yup.string()
|
||||
.email()
|
||||
.required()
|
||||
.transform((value) => value.toLowerCase()),
|
||||
email: Yup.string().email().required(),
|
||||
password: Yup.string().required(),
|
||||
})}
|
||||
onSubmit={async (values) => {
|
||||
|
||||
@@ -56,7 +56,6 @@ import {
|
||||
Dispatch,
|
||||
SetStateAction,
|
||||
use,
|
||||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
@@ -132,12 +131,18 @@ import {
|
||||
|
||||
import { getSourceMetadata } from "@/lib/sources";
|
||||
import { UserSettingsModal } from "./modal/UserSettingsModal";
|
||||
import { AlignStartVertical } from "lucide-react";
|
||||
import { AgenticMessage } from "./message/AgenticMessage";
|
||||
import AssistantModal from "../assistants/mine/AssistantModal";
|
||||
import { useSidebarShortcut } from "@/lib/browserUtilities";
|
||||
import {
|
||||
OperatingSystem,
|
||||
useOperatingSystem,
|
||||
useSidebarShortcut,
|
||||
} from "@/lib/browserUtilities";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
import { MessageChannel } from "node:worker_threads";
|
||||
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
|
||||
import { ErrorBanner } from "./message/Resubmit";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@@ -888,6 +893,24 @@ export function ChatPage({
|
||||
);
|
||||
const scrollDist = useRef<number>(0);
|
||||
|
||||
const updateScrollTracking = () => {
|
||||
const scrollDistance =
|
||||
endDivRef?.current?.getBoundingClientRect()?.top! -
|
||||
inputRef?.current?.getBoundingClientRect()?.top!;
|
||||
scrollDist.current = scrollDistance;
|
||||
setAboveHorizon(scrollDist.current > 500);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
if (scrollableDiv) {
|
||||
scrollableDiv.addEventListener("scroll", updateScrollTracking);
|
||||
return () => {
|
||||
scrollableDiv.removeEventListener("scroll", updateScrollTracking);
|
||||
};
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleInputResize = () => {
|
||||
setTimeout(() => {
|
||||
if (
|
||||
@@ -939,12 +962,33 @@ export function ChatPage({
|
||||
if (isVisible) return;
|
||||
|
||||
// Check if all messages are currently rendered
|
||||
// If all messages are already rendered, scroll immediately
|
||||
endDivRef.current.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
if (currentVisibleRange.end < messageHistory.length) {
|
||||
// Update visible range to include the last messages
|
||||
updateCurrentVisibleRange({
|
||||
start: Math.max(
|
||||
0,
|
||||
messageHistory.length -
|
||||
(currentVisibleRange.end - currentVisibleRange.start)
|
||||
),
|
||||
end: messageHistory.length,
|
||||
mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId,
|
||||
});
|
||||
|
||||
setHasPerformedInitialScroll(true);
|
||||
// Wait for the state update and re-render before scrolling
|
||||
setTimeout(() => {
|
||||
endDivRef.current?.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
setHasPerformedInitialScroll(true);
|
||||
}, 100);
|
||||
} else {
|
||||
// If all messages are already rendered, scroll immediately
|
||||
endDivRef.current.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
|
||||
setHasPerformedInitialScroll(true);
|
||||
}
|
||||
}, 50);
|
||||
|
||||
// Reset waitForScrollRef after 1.5 seconds
|
||||
@@ -965,6 +1009,11 @@ export function ChatPage({
|
||||
handleInputResize();
|
||||
}, [message]);
|
||||
|
||||
// tracks scrolling
|
||||
useEffect(() => {
|
||||
updateScrollTracking();
|
||||
}, [messageHistory]);
|
||||
|
||||
// used for resizing of the document sidebar
|
||||
const masterFlexboxRef = useRef<HTMLDivElement>(null);
|
||||
const [maxDocumentSidebarWidth, setMaxDocumentSidebarWidth] = useState<
|
||||
@@ -1163,7 +1212,6 @@ export function ChatPage({
|
||||
navigatingAway.current = false;
|
||||
let frozenSessionId = currentSessionId();
|
||||
updateCanContinue(false, frozenSessionId);
|
||||
setUncaughtError(null);
|
||||
|
||||
// Mark that we've sent a message for this session in the current page load
|
||||
markSessionMessageSent(frozenSessionId);
|
||||
@@ -1314,7 +1362,6 @@ export function ChatPage({
|
||||
let isStreamingQuestions = true;
|
||||
let includeAgentic = false;
|
||||
let secondLevelMessageId: number | null = null;
|
||||
let isAgentic: boolean = false;
|
||||
|
||||
let initialFetchDetails: null | {
|
||||
user_message_id: number;
|
||||
@@ -1477,9 +1524,6 @@ export function ChatPage({
|
||||
second_level_generating = true;
|
||||
}
|
||||
}
|
||||
if (Object.hasOwn(packet, "is_agentic")) {
|
||||
isAgentic = (packet as any).is_agentic;
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "refined_answer_improvement")) {
|
||||
isImprovement = (packet as RefinedAnswerImprovement)
|
||||
@@ -1513,7 +1557,6 @@ export function ChatPage({
|
||||
);
|
||||
} else if (Object.hasOwn(packet, "sub_question")) {
|
||||
updateChatState("toolBuilding", frozenSessionId);
|
||||
isAgentic = true;
|
||||
is_generating = true;
|
||||
sub_questions = constructSubQuestions(
|
||||
sub_questions,
|
||||
@@ -1714,7 +1757,6 @@ export function ChatPage({
|
||||
sub_questions: sub_questions,
|
||||
second_level_generating: second_level_generating,
|
||||
agentic_docs: agenticDocs,
|
||||
is_agentic: isAgentic,
|
||||
},
|
||||
...(includeAgentic
|
||||
? [
|
||||
@@ -1935,6 +1977,122 @@ export function ChatPage({
|
||||
|
||||
// Virtualization + Scrolling related effects and functions
|
||||
const scrollInitialized = useRef(false);
|
||||
interface VisibleRange {
|
||||
start: number;
|
||||
end: number;
|
||||
mostVisibleMessageId: number | null;
|
||||
}
|
||||
|
||||
const [visibleRange, setVisibleRange] = useState<
|
||||
Map<string | null, VisibleRange>
|
||||
>(() => {
|
||||
const initialRange: VisibleRange = {
|
||||
start: 0,
|
||||
end: BUFFER_COUNT,
|
||||
mostVisibleMessageId: null,
|
||||
};
|
||||
return new Map([[chatSessionIdRef.current, initialRange]]);
|
||||
});
|
||||
|
||||
// Function used to update current visible range. Only method for updating `visibleRange` state.
|
||||
const updateCurrentVisibleRange = (
|
||||
newRange: VisibleRange,
|
||||
forceUpdate?: boolean
|
||||
) => {
|
||||
if (
|
||||
scrollInitialized.current &&
|
||||
visibleRange.get(loadedIdSessionRef.current) == undefined &&
|
||||
!forceUpdate
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
setVisibleRange((prevState) => {
|
||||
const newState = new Map(prevState);
|
||||
newState.set(loadedIdSessionRef.current, newRange);
|
||||
return newState;
|
||||
});
|
||||
};
|
||||
|
||||
// Set first value for visibleRange state on page load / refresh.
|
||||
const initializeVisibleRange = () => {
|
||||
const upToDatemessageHistory = buildLatestMessageChain(
|
||||
currentMessageMap(completeMessageDetail)
|
||||
);
|
||||
|
||||
if (!scrollInitialized.current && upToDatemessageHistory.length > 0) {
|
||||
const newEnd = Math.max(upToDatemessageHistory.length, BUFFER_COUNT);
|
||||
const newStart = Math.max(0, newEnd - BUFFER_COUNT);
|
||||
const newMostVisibleMessageId =
|
||||
upToDatemessageHistory[newEnd - 1]?.messageId;
|
||||
|
||||
updateCurrentVisibleRange(
|
||||
{
|
||||
start: newStart,
|
||||
end: newEnd,
|
||||
mostVisibleMessageId: newMostVisibleMessageId,
|
||||
},
|
||||
true
|
||||
);
|
||||
scrollInitialized.current = true;
|
||||
}
|
||||
};
|
||||
|
||||
const updateVisibleRangeBasedOnScroll = () => {
|
||||
if (!scrollInitialized.current) return;
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
if (!scrollableDiv) return;
|
||||
|
||||
const viewportHeight = scrollableDiv.clientHeight;
|
||||
let mostVisibleMessageIndex = -1;
|
||||
|
||||
messageHistory.forEach((message, index) => {
|
||||
const messageElement = document.getElementById(
|
||||
`message-${message.messageId}`
|
||||
);
|
||||
if (messageElement) {
|
||||
const rect = messageElement.getBoundingClientRect();
|
||||
const isVisible = rect.bottom <= viewportHeight && rect.bottom > 0;
|
||||
if (isVisible && index > mostVisibleMessageIndex) {
|
||||
mostVisibleMessageIndex = index;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (mostVisibleMessageIndex !== -1) {
|
||||
const startIndex = Math.max(0, mostVisibleMessageIndex - BUFFER_COUNT);
|
||||
const endIndex = Math.min(
|
||||
messageHistory.length,
|
||||
mostVisibleMessageIndex + BUFFER_COUNT + 1
|
||||
);
|
||||
|
||||
updateCurrentVisibleRange({
|
||||
start: startIndex,
|
||||
end: endIndex,
|
||||
mostVisibleMessageId: messageHistory[mostVisibleMessageIndex].messageId,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
initializeVisibleRange();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [router, messageHistory]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
|
||||
const handleScroll = () => {
|
||||
updateVisibleRangeBasedOnScroll();
|
||||
};
|
||||
|
||||
scrollableDiv?.addEventListener("scroll", handleScroll);
|
||||
|
||||
return () => {
|
||||
scrollableDiv?.removeEventListener("scroll", handleScroll);
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [messageHistory]);
|
||||
|
||||
const imageFileInMessageHistory = useMemo(() => {
|
||||
return messageHistory
|
||||
@@ -1944,6 +2102,11 @@ export function ChatPage({
|
||||
);
|
||||
}, [messageHistory]);
|
||||
|
||||
const currentVisibleRange = visibleRange.get(currentSessionId()) || {
|
||||
start: 0,
|
||||
end: 0,
|
||||
mostVisibleMessageId: null,
|
||||
};
|
||||
useSendMessageToParent();
|
||||
|
||||
useEffect(() => {
|
||||
@@ -1983,15 +2146,6 @@ export function ChatPage({
|
||||
|
||||
const currentPersona = alternativeAssistant || liveAssistant;
|
||||
|
||||
const HORIZON_DISTANCE = 800;
|
||||
const handleScroll = useCallback(() => {
|
||||
const scrollDistance =
|
||||
endDivRef?.current?.getBoundingClientRect()?.top! -
|
||||
inputRef?.current?.getBoundingClientRect()?.top!;
|
||||
scrollDist.current = scrollDistance;
|
||||
setAboveHorizon(scrollDist.current > HORIZON_DISTANCE);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleSlackChatRedirect = async () => {
|
||||
if (!slackChatId) return;
|
||||
@@ -2063,26 +2217,6 @@ export function ChatPage({
|
||||
const [sharedChatSession, setSharedChatSession] =
|
||||
useState<ChatSession | null>();
|
||||
|
||||
const handleResubmitLastMessage = () => {
|
||||
// Grab the last user-type message
|
||||
const lastUserMsg = messageHistory
|
||||
.slice()
|
||||
.reverse()
|
||||
.find((m) => m.type === "user");
|
||||
if (!lastUserMsg) {
|
||||
setPopup({
|
||||
message: "No previously-submitted user message found.",
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
// We call onSubmit, passing a `messageOverride`
|
||||
onSubmit({
|
||||
messageIdToResend: lastUserMsg.messageId,
|
||||
messageOverride: lastUserMsg.message,
|
||||
});
|
||||
};
|
||||
|
||||
const showShareModal = (chatSession: ChatSession) => {
|
||||
setSharedChatSession(chatSession);
|
||||
};
|
||||
@@ -2462,7 +2596,6 @@ export function ChatPage({
|
||||
{...getRootProps()}
|
||||
>
|
||||
<div
|
||||
onScroll={handleScroll}
|
||||
className={`w-full h-[calc(100vh-160px)] flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
|
||||
ref={scrollableDivRef}
|
||||
>
|
||||
@@ -2520,7 +2653,18 @@ export function ChatPage({
|
||||
// NOTE: temporarily removing this to fix the scroll bug
|
||||
// (hasPerformedInitialScroll ? "" : "invisible")
|
||||
>
|
||||
{messageHistory.map((message, i) => {
|
||||
{(messageHistory.length < BUFFER_COUNT
|
||||
? messageHistory
|
||||
: messageHistory.slice(
|
||||
currentVisibleRange.start,
|
||||
currentVisibleRange.end
|
||||
)
|
||||
).map((message, fauxIndex) => {
|
||||
const i =
|
||||
messageHistory.length < BUFFER_COUNT
|
||||
? fauxIndex
|
||||
: fauxIndex + currentVisibleRange.start;
|
||||
|
||||
const messageMap = currentMessageMap(
|
||||
completeMessageDetail
|
||||
);
|
||||
@@ -2665,9 +2809,9 @@ export function ChatPage({
|
||||
: null
|
||||
}
|
||||
>
|
||||
{message.is_agentic ? (
|
||||
{message.sub_questions &&
|
||||
message.sub_questions.length > 0 ? (
|
||||
<AgenticMessage
|
||||
resubmit={handleResubmitLastMessage}
|
||||
error={uncaughtError}
|
||||
isStreamingQuestions={
|
||||
message.isStreamingQuestions ?? false
|
||||
@@ -3015,18 +3159,21 @@ export function ChatPage({
|
||||
currentPersona={liveAssistant}
|
||||
messageId={message.messageId}
|
||||
content={
|
||||
<ErrorBanner
|
||||
resubmit={handleResubmitLastMessage}
|
||||
error={message.message}
|
||||
showStackTrace={
|
||||
message.stackTrace
|
||||
? () =>
|
||||
setStackTraceModalContent(
|
||||
message.stackTrace!
|
||||
)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
<p className="text-red-700 text-sm my-auto">
|
||||
{message.message}
|
||||
{message.stackTrace && (
|
||||
<span
|
||||
onClick={() =>
|
||||
setStackTraceModalContent(
|
||||
message.stackTrace!
|
||||
)
|
||||
}
|
||||
className="ml-2 cursor-pointer underline"
|
||||
>
|
||||
Show stack trace.
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -15,8 +15,8 @@ export function ChatSearchGroup({
|
||||
}: ChatSearchGroupProps) {
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-neutral-800/90 py-2 px-4 px-4">
|
||||
<div className="text-xs font-medium leading-4 text-neutral-600 dark:text-neutral-400">
|
||||
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-gray-800/90 py-2 px-4 px-4">
|
||||
<div className="text-xs font-medium leading-4 text-gray-600 dark:text-gray-400">
|
||||
{title}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import React from "react";
|
||||
import { MessageSquare } from "lucide-react";
|
||||
import { ChatSessionSummary } from "./interfaces";
|
||||
import { truncateString } from "@/lib/utils";
|
||||
|
||||
interface ChatSearchItemProps {
|
||||
chat: ChatSessionSummary;
|
||||
@@ -12,12 +11,12 @@ export function ChatSearchItem({ chat, onSelect }: ChatSearchItemProps) {
|
||||
return (
|
||||
<li>
|
||||
<div className="cursor-pointer" onClick={() => onSelect(chat.id)}>
|
||||
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-700">
|
||||
<div className="flex max-w-full mx-2 items-center">
|
||||
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-800">
|
||||
<div className="flex items-center">
|
||||
<MessageSquare className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
|
||||
<div className="relative max-w-full grow overflow-hidden whitespace-nowrap pl-4">
|
||||
<div className="text-sm max-w-full dark:text-neutral-200">
|
||||
{truncateString(chat.name || "Untitled Chat", 90)}
|
||||
<div className="relative grow overflow-hidden whitespace-nowrap pl-4">
|
||||
<div className="text-sm dark:text-neutral-200">
|
||||
{chat.name || "Untitled Chat"}
|
||||
</div>
|
||||
</div>
|
||||
<div className="opacity-0 group-hover:opacity-100 transition-opacity text-xs text-neutral-500 dark:text-neutral-400">
|
||||
|
||||
@@ -168,7 +168,7 @@ const FolderItem = ({
|
||||
};
|
||||
|
||||
const folders = folder.chat_sessions.sort((a, b) => {
|
||||
return a.time_updated.localeCompare(b.time_updated);
|
||||
return a.time_created.localeCompare(b.time_created);
|
||||
});
|
||||
|
||||
// Determine whether to show the trash can icon
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import React, { useState, useEffect, useCallback } from "react";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { InputPrompt } from "@/app/chat/interfaces";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { PlusIcon } from "@/components/icons/icons";
|
||||
import { MoreVertical, XIcon } from "lucide-react";
|
||||
import { TrashIcon, PlusIcon } from "@/components/icons/icons";
|
||||
import { MoreVertical, CheckIcon, XIcon } from "lucide-react";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import Title from "@/components/ui/title";
|
||||
import Text from "@/components/ui/text";
|
||||
@@ -153,6 +153,114 @@ export default function InputPrompts() {
|
||||
}
|
||||
};
|
||||
|
||||
const PromptCard = ({ prompt }: { prompt: InputPrompt }) => {
|
||||
const isEditing = editingPromptId === prompt.id;
|
||||
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
|
||||
const [localContent, setLocalContent] = useState(prompt.content);
|
||||
|
||||
// Sync local edits with any prompt changes from outside
|
||||
useEffect(() => {
|
||||
setLocalPrompt(prompt.prompt);
|
||||
setLocalContent(prompt.content);
|
||||
}, [prompt, isEditing]);
|
||||
|
||||
const handleLocalEdit = (field: "prompt" | "content", value: string) => {
|
||||
if (field === "prompt") {
|
||||
setLocalPrompt(value);
|
||||
} else {
|
||||
setLocalContent(value);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSaveLocal = () => {
|
||||
handleSave(prompt.id, localPrompt, localContent);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
|
||||
{isEditing ? (
|
||||
<>
|
||||
<div className="absolute top-2 right-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setEditingPromptId(null);
|
||||
fetchInputPrompts(); // Revert changes from server
|
||||
}}
|
||||
>
|
||||
<XIcon size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex">
|
||||
<div className="flex-grow mr-4">
|
||||
<Textarea
|
||||
value={localPrompt}
|
||||
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
|
||||
className="mb-2 resize-none"
|
||||
placeholder="Prompt"
|
||||
/>
|
||||
<Textarea
|
||||
value={localContent}
|
||||
onChange={(e) => handleLocalEdit("content", e.target.value)}
|
||||
className="resize-vertical min-h-[100px]"
|
||||
placeholder="Content"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end">
|
||||
<Button onClick={handleSaveLocal}>
|
||||
{prompt.id ? "Save" : "Create"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="mb-2 flex gap-x-2 ">
|
||||
<p className="font-semibold">{prompt.prompt}</p>
|
||||
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{isPromptPublic(prompt) && (
|
||||
<TooltipContent>
|
||||
<p>This is a built-in prompt and cannot be edited</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<div className="whitespace-pre-wrap">{prompt.content}</div>
|
||||
<div className="absolute top-2 right-2">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
|
||||
<Button
|
||||
className="!hover:bg-transparent"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<MoreVertical size={14} />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
{!isPromptPublic(prompt) && (
|
||||
<DropdownMenuItem onClick={() => handleEdit(prompt.id)}>
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => handleDelete(prompt.id)}>
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mx-auto max-w-4xl">
|
||||
<div className="absolute top-4 left-4">
|
||||
@@ -164,21 +272,13 @@ export default function InputPrompts() {
|
||||
<Title>Prompt Shortcuts</Title>
|
||||
<Text>
|
||||
Manage and customize prompt shortcuts for your assistants. Use your
|
||||
prompt shortcuts by starting a new message with "/" in
|
||||
chat.
|
||||
prompt shortcuts by starting a new message “/” in chat.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{inputPrompts.map((prompt) => (
|
||||
<PromptCard
|
||||
key={prompt.id}
|
||||
prompt={prompt}
|
||||
onEdit={handleEdit}
|
||||
onSave={handleSave}
|
||||
onDelete={handleDelete}
|
||||
isEditing={editingPromptId === prompt.id}
|
||||
/>
|
||||
<PromptCard key={prompt.id} prompt={prompt} />
|
||||
))}
|
||||
|
||||
{isCreatingNew ? (
|
||||
@@ -215,129 +315,3 @@ export default function InputPrompts() {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface PromptCardProps {
|
||||
prompt: InputPrompt;
|
||||
onEdit: (id: number) => void;
|
||||
onSave: (id: number, prompt: string, content: string) => void;
|
||||
onDelete: (id: number) => void;
|
||||
isEditing: boolean;
|
||||
}
|
||||
|
||||
const PromptCard: React.FC<PromptCardProps> = ({
|
||||
prompt,
|
||||
onEdit,
|
||||
onSave,
|
||||
onDelete,
|
||||
isEditing,
|
||||
}) => {
|
||||
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
|
||||
const [localContent, setLocalContent] = useState(prompt.content);
|
||||
|
||||
useEffect(() => {
|
||||
setLocalPrompt(prompt.prompt);
|
||||
setLocalContent(prompt.content);
|
||||
}, [prompt, isEditing]);
|
||||
|
||||
const handleLocalEdit = useCallback(
|
||||
(field: "prompt" | "content", value: string) => {
|
||||
if (field === "prompt") {
|
||||
setLocalPrompt(value);
|
||||
} else {
|
||||
setLocalContent(value);
|
||||
}
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleSaveLocal = useCallback(() => {
|
||||
onSave(prompt.id, localPrompt, localContent);
|
||||
}, [prompt.id, localPrompt, localContent, onSave]);
|
||||
|
||||
const isPromptPublic = useCallback((p: InputPrompt): boolean => {
|
||||
return p.is_public;
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
|
||||
{isEditing ? (
|
||||
<>
|
||||
<div className="absolute top-2 right-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
onEdit(0);
|
||||
}}
|
||||
>
|
||||
<XIcon size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex">
|
||||
<div className="flex-grow mr-4">
|
||||
<Textarea
|
||||
value={localPrompt}
|
||||
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
|
||||
className="mb-2 resize-none"
|
||||
placeholder="Prompt"
|
||||
/>
|
||||
<Textarea
|
||||
value={localContent}
|
||||
onChange={(e) => handleLocalEdit("content", e.target.value)}
|
||||
className="resize-vertical min-h-[100px]"
|
||||
placeholder="Content"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end">
|
||||
<Button onClick={handleSaveLocal}>
|
||||
{prompt.id ? "Save" : "Create"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="mb-2 flex gap-x-2 ">
|
||||
<p className="font-semibold">{prompt.prompt}</p>
|
||||
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{isPromptPublic(prompt) && (
|
||||
<TooltipContent>
|
||||
<p>This is a built-in prompt and cannot be edited</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<div className="whitespace-pre-wrap">{prompt.content}</div>
|
||||
<div className="absolute top-2 right-2">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
|
||||
<Button
|
||||
className="!hover:bg-transparent"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<MoreVertical size={14} />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
{!isPromptPublic(prompt) && (
|
||||
<DropdownMenuItem onClick={() => onEdit(prompt.id)}>
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => onDelete(prompt.id)}>
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
import { SourceChip } from "../input/ChatInputBar";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { useState } from "react";
|
||||
import { InputPrompt } from "../interfaces";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { XIcon } from "@/components/icons/icons";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { MoreVertical } from "lucide-react";
|
||||
|
||||
export const PromptCard = ({
|
||||
prompt,
|
||||
editingPromptId,
|
||||
setEditingPromptId,
|
||||
handleSave,
|
||||
handleDelete,
|
||||
isPromptPublic,
|
||||
handleEdit,
|
||||
fetchInputPrompts,
|
||||
}: {
|
||||
prompt: InputPrompt;
|
||||
editingPromptId: number | null;
|
||||
setEditingPromptId: (id: number | null) => void;
|
||||
handleSave: (id: number, prompt: string, content: string) => void;
|
||||
handleDelete: (id: number) => void;
|
||||
isPromptPublic: (prompt: InputPrompt) => boolean;
|
||||
handleEdit: (id: number) => void;
|
||||
fetchInputPrompts: () => void;
|
||||
}) => {
|
||||
const isEditing = editingPromptId === prompt.id;
|
||||
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
|
||||
const [localContent, setLocalContent] = useState(prompt.content);
|
||||
|
||||
// Sync local edits with any prompt changes from outside
|
||||
useEffect(() => {
|
||||
setLocalPrompt(prompt.prompt);
|
||||
setLocalContent(prompt.content);
|
||||
}, [prompt, isEditing]);
|
||||
|
||||
const handleLocalEdit = (field: "prompt" | "content", value: string) => {
|
||||
if (field === "prompt") {
|
||||
setLocalPrompt(value);
|
||||
} else {
|
||||
setLocalContent(value);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSaveLocal = () => {
|
||||
handleSave(prompt.id, localPrompt, localContent);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
|
||||
{isEditing ? (
|
||||
<>
|
||||
<div className="absolute top-2 right-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setEditingPromptId(null);
|
||||
fetchInputPrompts(); // Revert changes from server
|
||||
}}
|
||||
>
|
||||
<XIcon size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex">
|
||||
<div className="flex-grow mr-4">
|
||||
<Textarea
|
||||
value={localPrompt}
|
||||
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
|
||||
className="mb-2 resize-none"
|
||||
placeholder="Prompt"
|
||||
/>
|
||||
<Textarea
|
||||
value={localContent}
|
||||
onChange={(e) => handleLocalEdit("content", e.target.value)}
|
||||
className="resize-vertical min-h-[100px]"
|
||||
placeholder="Content"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end">
|
||||
<Button onClick={handleSaveLocal}>
|
||||
{prompt.id ? "Save" : "Create"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="mb-2 flex gap-x-2 ">
|
||||
<p className="font-semibold">{prompt.prompt}</p>
|
||||
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{isPromptPublic(prompt) && (
|
||||
<TooltipContent>
|
||||
<p>This is a built-in prompt and cannot be edited</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<div className="whitespace-pre-wrap">{prompt.content}</div>
|
||||
<div className="absolute top-2 right-2">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
|
||||
<Button
|
||||
className="!hover:bg-transparent"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<MoreVertical size={14} />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
{!isPromptPublic(prompt) && (
|
||||
<DropdownMenuItem onClick={() => handleEdit(prompt.id)}>
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => handleDelete(prompt.id)}>
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -154,6 +154,7 @@ export const SourceChip = ({
|
||||
gap-x-1
|
||||
h-6
|
||||
${onClick ? "cursor-pointer" : ""}
|
||||
animate-fade-in-scale
|
||||
`}
|
||||
>
|
||||
{icon}
|
||||
|
||||
@@ -70,7 +70,6 @@ export interface ChatSession {
|
||||
name: string;
|
||||
persona_id: number;
|
||||
time_created: string;
|
||||
time_updated: string;
|
||||
shared_status: ChatSessionSharedStatus;
|
||||
folder_id: number | null;
|
||||
current_alternate_model: string;
|
||||
@@ -104,7 +103,6 @@ export interface Message {
|
||||
overridden_model?: string;
|
||||
stopReason?: StreamStopReason | null;
|
||||
sub_questions?: SubQuestionDetail[] | null;
|
||||
is_agentic?: boolean | null;
|
||||
|
||||
// Streaming only
|
||||
second_level_generating?: boolean;
|
||||
@@ -124,7 +122,6 @@ export interface BackendChatSession {
|
||||
persona_icon_shape: number | null;
|
||||
messages: BackendMessage[];
|
||||
time_created: string;
|
||||
time_updated: string;
|
||||
shared_status: ChatSessionSharedStatus;
|
||||
current_temperature_override: number | null;
|
||||
current_alternate_model?: string;
|
||||
@@ -151,7 +148,6 @@ export interface BackendMessage {
|
||||
comments: any;
|
||||
parentMessageId: number | null;
|
||||
refined_answer_improvement: boolean | null;
|
||||
is_agentic: boolean | null;
|
||||
}
|
||||
|
||||
export interface MessageResponseIDInfo {
|
||||
|
||||
@@ -48,10 +48,10 @@ export function getChatRetentionInfo(
|
||||
): ChatRetentionInfo {
|
||||
// If `maximum_chat_retention_days` isn't set- never display retention warning.
|
||||
const chatRetentionDays = settings.maximum_chat_retention_days || 10000;
|
||||
const updatedDate = new Date(chatSession.time_updated);
|
||||
const createdDate = new Date(chatSession.time_created);
|
||||
const today = new Date();
|
||||
const daysFromCreation = Math.ceil(
|
||||
(today.getTime() - updatedDate.getTime()) / (1000 * 3600 * 24)
|
||||
(today.getTime() - createdDate.getTime()) / (1000 * 3600 * 24)
|
||||
);
|
||||
const daysUntilExpiration = chatRetentionDays - daysFromCreation;
|
||||
const showRetentionWarning =
|
||||
@@ -419,7 +419,7 @@ export function groupSessionsByDateRange(chatSessions: ChatSession[]) {
|
||||
};
|
||||
|
||||
chatSessions.forEach((chatSession) => {
|
||||
const chatSessionDate = new Date(chatSession.time_updated);
|
||||
const chatSessionDate = new Date(chatSession.time_created);
|
||||
|
||||
const diffTime = today.getTime() - chatSessionDate.getTime();
|
||||
const diffDays = diffTime / (1000 * 3600 * 24); // Convert time difference to days
|
||||
@@ -501,7 +501,6 @@ export function processRawChatHistory(
|
||||
sub_questions: subQuestions,
|
||||
isImprovement:
|
||||
(messageInfo.refined_answer_improvement as unknown as boolean) || false,
|
||||
is_agentic: messageInfo.is_agentic,
|
||||
};
|
||||
|
||||
messages.set(messageInfo.message_id, message);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user