Compare commits


39 Commits

Author SHA1 Message Date
pablonyx
ddaaaeeb40 k 2025-02-27 20:28:49 -08:00
pablonyx
4e31bc19dc k 2025-02-27 17:02:11 -08:00
pablonyx
1c0c80116e fix function 2025-02-27 17:00:05 -08:00
pablonyx
27891ad34d k 2025-02-27 16:51:06 -08:00
pablonyx
511341fd7c k 2025-02-27 16:49:23 -08:00
pablonyx
88d4e7defa fix handling 2025-02-27 16:14:02 -08:00
pablonyx
a7ba0da8cc Lowercase multi tenant email mapping (#4141) 2025-02-27 15:33:40 -08:00
Richard Kuo (Danswer)
aaced6d551 scan images 2025-02-27 15:25:29 -08:00
Richard Kuo (Danswer)
4c230f92ea trivy test 2025-02-27 15:05:03 -08:00
Richard Kuo (Danswer)
07d75b04d1 enable trivy scan 2025-02-27 14:22:44 -08:00
evan-danswer
a8d10750c1 fix propagation of is_agentic (#4150) 2025-02-27 11:56:51 -08:00
pablonyx
85e3ed57f1 Order chat sessions by time updated, not created (#4143)
* order chat sessions by time updated, not created

* quick update

* k
2025-02-27 17:35:42 +00:00
pablonyx
e10cc8ccdb Multi tenant user google auth fix (#4145) 2025-02-27 10:35:38 -08:00
pablonyx
7018bc974b Better looking errors (#4050)
* add error handling

* fix

* k
2025-02-27 04:58:25 +00:00
pablonyx
9c9075d71d Minor improvements to provisioning (#4109)
* quick fix

* k

* nit
2025-02-27 04:57:31 +00:00
pablonyx
338e084062 Improved tenant handling for slack bot (#4099) 2025-02-27 04:06:26 +00:00
pablonyx
2f64031f5c Improved tenant handling for slack bot1 (#4104) 2025-02-27 03:40:50 +00:00
pablonyx
abb74f2eaa Improved chat search (#4137)
* functional + fast

* k

* adapt

* k

* nit

* k

* k

* fix typing

* k
2025-02-27 02:27:45 +00:00
pablonyx
a3e3d83b7e Improve viewable assistant logic (#4125)
* k

* quick fix

* k
2025-02-27 01:24:39 +00:00
pablonyx
4dc88ca037 debug playwright failure case 2025-02-26 17:32:26 -08:00
rkuo-danswer
11e7e1c4d6 log processed tenant count (#4139)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-26 17:26:48 -08:00
pablonyx
f2d74ce540 Address Auth Edge Case (#4138) 2025-02-26 17:24:23 -08:00
rkuo-danswer
25389c5120 first cut at anonymizing query history (#4123)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-26 21:32:01 +00:00
pablonyx
ad0721ecd8 update (#4086) 2025-02-26 18:12:07 +00:00
pablonyx
426a8842ae Markdown copying / html formatting (#4120)
* k

* delete unnecessary util
2025-02-26 04:56:38 +00:00
pablonyx
a98dcbc7de Update tenant logic (#4122)
* k

* k

* k

* quick nit

* nit
2025-02-26 03:53:46 +00:00
pablonyx
6f389dc100 Improve lengthy chats (#4126)
* remove scroll

* working well

* nit

* k

* nit
2025-02-26 03:22:21 +00:00
pablonyx
d56177958f fix email headers (#4100) 2025-02-26 03:12:30 +00:00
Kaveen Jayamanna
0e42ae9024 Content of .xlsl are not properly read during indexing. (#4035) 2025-02-25 21:10:47 -08:00
Weves
ce2b4de245 temp remove 2025-02-25 20:46:55 -08:00
Chris Weaver
a515aa78d2 Fix confluence test (#4130) 2025-02-26 03:03:54 +00:00
Weves
23073d91b9 reduce number of chars to index for search 2025-02-25 19:27:50 -08:00
Chris Weaver
f767b1f476 Fix confluence permission syncing at scale (#4129)
* Fix confluence permission syncing at scale

* Remove line

* Better log message

* Adjust log
2025-02-25 19:22:52 -08:00
pablonyx
9ffc8cb2c4 k 2025-02-25 18:15:49 -08:00
pablonyx
98bfb58147 Handle bad slack configurations– multi tenant (#4118)
* k

* quick nit

* k

* k
2025-02-25 22:22:54 +00:00
evan-danswer
6ce810e957 faster indexing status at scale plus minor cleanups (#4081)
* faster indexing status at scale plus minor cleanups

* mypy

* address chris comments

* remove extra prints
2025-02-25 21:22:26 +00:00
pablonyx
07b0b57b31 (nit) bump timeout 2025-02-25 14:10:30 -08:00
pablonyx
118cdd7701 Chat search (#4113)
* add chat search

* don't add the bible

* base functional

* k

* k

* functioning

* functioning well

* functioning well

* k

* delete bible

* quick cleanup

* quick cleanup

* k

* fixed frontend hooks

* delete bible

* nit

* nit

* nit

* fix build

* k

* improved debouncing

* address comments

* fix alembic

* k
2025-02-25 20:49:46 +00:00
rkuo-danswer
ac83b4c365 validate connector deletion (#4108)
* validate connector deletion

* fixes

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-25 20:35:21 +00:00
142 changed files with 3744 additions and 1470 deletions

View File

@@ -53,24 +53,90 @@ jobs:
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
if: always()
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@0.29.0
# with:
# sarif_file: trivy-results.sarif
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
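
For local debugging, the same scans can be approximated outside CI. This is a rough sketch, assuming the trivy CLI is installed and the images are already pulled; the flags mirror the workflow steps above, and exact flag support depends on the installed trivy version.

import subprocess

# Images scanned by the scan-trivy job above.
IMAGES = [
    "onyxdotapp/onyx-backend:latest",
    "onyxdotapp/onyx-web-server:latest",
    "onyxdotapp/onyx-model-server:latest",
]

for image in IMAGES:
    # License scanner only, HIGH/CRITICAL findings; --exit-code 0 keeps a
    # finding-filled report from aborting the loop, matching the workflow.
    subprocess.run(
        [
            "trivy", "image",
            "--scanners", "license",
            "--severity", "HIGH,CRITICAL",
            "--exit-code", "0",
            image,
        ],
        check=True,
    )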

View File

@@ -0,0 +1,84 @@
"""improved index
Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None
# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)
def downgrade() -> None:
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -0,0 +1,32 @@
"""add index
Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
# op.execute(
# """
# CREATE INDEX idx_chat_message_message_lower
# ON chat_message (LOWER(substring(message, 1, 1500)))
# """
# )
pass
def downgrade() -> None:
# Drop the index
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -0,0 +1,36 @@
"""force lowercase all users
Revision ID: f11b408e39d3
Revises: 3bd4c84fe72f
Create Date: 2025-02-26 17:04:55.683500
"""
# revision identifiers, used by Alembic.
revision = "f11b408e39d3"
down_revision = "3bd4c84fe72f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing user emails to lowercase
from alembic import op
op.execute(
"""
UPDATE "user"
SET email = LOWER(email)
"""
)
# 2) Add a check constraint to ensure emails are always lowercase
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
def downgrade() -> None:
# Drop the check constraint
from alembic import op
op.drop_constraint("ensure_lowercase_email", "user", type_="check")

View File

@@ -0,0 +1,42 @@
"""lowercase multi-tenant user auth
Revision ID: 34e3630c7f32
Revises: a4f6ee863c47
Create Date: 2025-02-26 15:03:01.211894
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "34e3630c7f32"
down_revision = "a4f6ee863c47"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing rows to lowercase
op.execute(
"""
UPDATE user_tenant_mapping
SET email = LOWER(email)
"""
)
# 2) Add a check constraint so that emails cannot be written in uppercase
op.create_check_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
"email = LOWER(email)",
schema="public",
)
def downgrade() -> None:
# Drop the check constraint
op.drop_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
schema="public",
type_="check",
)
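
With the check constraint in place, any code path writing to user_tenant_mapping must lowercase the address first or the insert is rejected. A minimal sketch of the normalization; the helper name is illustrative and not from the codebase.

def normalize_email(email: str) -> str:
    # Matches the "email = LOWER(email)" check constraint added above.
    return email.lower()

# Usage: normalize before any insert or lookup, e.g.
# db_session.add(UserTenantMapping(email=normalize_email(raw_email), tenant_id=tenant_id))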

View File

@@ -5,11 +5,9 @@ from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -18,10 +16,8 @@ logger = setup_logger()
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -35,24 +31,19 @@ def perform_ttl_management_task(
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
def check_ttl_management_task(*, tenant_id: str) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@celery_app.task(
@@ -60,9 +51,9 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,

View File

@@ -18,7 +18,7 @@ logger = setup_logger()
def monitor_usergroup_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
"""This function is likely to move in the worker refactor happening next."""
fence_key = key_bytes.decode("utf-8")

View File

@@ -2,6 +2,7 @@ import csv
import io
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from uuid import UUID
from fastapi import APIRouter
@@ -21,8 +22,10 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
@@ -35,6 +38,8 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
def fetch_and_process_chat_session_history(
db_session: Session,
@@ -107,6 +112,17 @@ def get_user_chat_sessions(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
# we specifically don't allow this endpoint if "anonymized" since
# this is a direct query on the user id
if ONYX_QUERY_HISTORY_TYPE in [
QueryHistoryType.DISABLED,
QueryHistoryType.ANONYMIZED,
]:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Per user query history has been disabled by the administrator.",
)
try:
chat_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session, limit=0
@@ -122,6 +138,7 @@ def get_user_chat_sessions(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
@@ -141,6 +158,12 @@ def get_chat_session_history(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
@@ -157,11 +180,16 @@ def get_chat_session_history(
feedback_filter=feedback_type,
)
minimal_chat_sessions: list[ChatSessionMinimal] = []
for chat_session in page_of_chat_sessions:
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
minimal_chat_sessions.append(minimal_chat_session)
return PaginatedReturn(
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
items=minimal_chat_sessions,
total_items=total_filtered_chat_sessions_count,
)
@@ -172,6 +200,12 @@ def get_chat_session_admin(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
try:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
@@ -193,6 +227,9 @@ def get_chat_session_admin(
f"Could not create snapshot for chat session with id '{chat_session_id}'",
)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
return snapshot
@@ -203,6 +240,12 @@ def get_query_history_as_csv(
end: datetime | None = None,
db_session: Session = Depends(get_session),
) -> StreamingResponse:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
@@ -213,6 +256,9 @@ def get_query_history_as_csv(
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in complete_chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
question_answer_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
)
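
The new gating is driven entirely by ONYX_QUERY_HISTORY_TYPE. The sketch below consolidates the pieces spread across the app_configs and constants diffs further down: the accepted values and how the environment variable is parsed (unset falls back to normal, and the value is lowercased before matching).

import os
from enum import Enum

class QueryHistoryType(str, Enum):
    DISABLED = "disabled"      # history endpoints above return 403
    ANONYMIZED = "anonymized"  # history served with emails masked; per-user lookup blocked
    NORMAL = "normal"          # unchanged behavior

ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
    (os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
)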

View File

@@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(tenant_id: str) -> BillingInformation:
def fetch_billing_information(
tenant_id: str,
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = BillingInformation(**response.json())
return billing_info
response_data = response.json()
# Check if the response indicates no subscription
if (
isinstance(response_data, dict)
and "subscribed" in response_data
and not response_data["subscribed"]
):
return SubscriptionStatusResponse(**response_data)
# Otherwise, parse as BillingInformation
return BillingInformation(**response_data)
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:

View File

@@ -200,25 +200,6 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
def configure_default_api_keys(db_session: Session) -> None:
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
@@ -227,6 +208,7 @@ def configure_default_api_keys(db_session: Session) -> None:
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
@@ -238,6 +220,26 @@ def configure_default_api_keys(db_session: Session) -> None:
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,

View File

@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
from shared_configs.configs import MULTI_TENANT
_API_KEY_HEADER_NAME = "Authorization"
@@ -35,8 +36,7 @@ class ApiKeyDescriptor(BaseModel):
def generate_api_key(tenant_id: str | None = None) -> str:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
if not MULTI_TENANT or not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
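
With the MULTI_TENANT guard, old-style keys are still produced whenever multi-tenancy is off or no tenant id is supplied; otherwise the tenant id is URL-encoded into the key. A rough sketch of both shapes; the prefix, length, and exact layout of the tenant-encoded key are assumptions based on this fragment, not the full implementation.

import secrets
from urllib.parse import quote

_API_KEY_PREFIX = "on_"  # illustrative; the real constant lives in this module
_API_KEY_LEN = 192       # illustrative

def generate_api_key_sketch(tenant_id: str | None, multi_tenant: bool) -> str:
    # Old-style key: no tenant information embedded.
    if not multi_tenant or not tenant_id:
        return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
    # Tenant-aware key: URL-encode the tenant id so it can be recovered when
    # the key is parsed later (the separator here is illustrative).
    return _API_KEY_PREFIX + quote(tenant_id) + "." + secrets.token_urlsafe(_API_KEY_LEN)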

View File

@@ -2,6 +2,8 @@ import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -13,6 +15,7 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
@@ -150,8 +153,9 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
@@ -173,7 +177,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>Were sorry to see you go.</p>"
"<p>We're sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
@@ -239,13 +243,13 @@ def send_user_email_invite(
def send_forgot_password_email(
user_email: str,
token: str,
tenant_id: str,
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if tenant_id:
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)

View File

@@ -214,7 +214,7 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}
user: User
user: User | None = None
try:
# Attempt to get user by OAuth account
@@ -420,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = await self.get_by_email(account_email)
user = await self.user_db.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# Make sure user is not None before adding OAuth account
if user is not None:
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
else:
# This shouldn't happen since get_by_email would raise UserNotExists
# but adding as a safeguard
raise exceptions.UserNotExists()
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
@@ -439,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = await self.user_db.create(user_dict)
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
# Add OAuth account only if user creation was successful
if user is not None:
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
raise HTTPException(
status_code=500, detail="Failed to create user account"
)
else:
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# User exists, update OAuth account if needed
if user is not None: # Add explicit check
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# Ensure user is not None before proceeding
if user is None:
raise HTTPException(
status_code=500, detail="Failed to authenticate or create user"
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
@@ -553,7 +568,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async_return_default_schema,
)(email=user.email)
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None

View File

@@ -2,6 +2,7 @@ import logging
import multiprocessing
import time
from typing import Any
from typing import cast
import sentry_sdk
from celery import Task
@@ -131,9 +132,9 @@ def on_task_postrun(
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
tenant_id = kwargs.get("tenant_id")
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "

View File

@@ -92,7 +92,8 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
"""This is a redis specific way to build a list of tasks in a queue and return them
as a set.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.
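
For context, the "redis specific" part means reading the broker's queue list directly: with the Redis broker, Celery keeps pending messages in a list keyed by queue name, and each message carries its task id in the headers. A hedged sketch of that lookup (the message layout can vary by Celery version and serializer):

import json
from redis import Redis

def get_queued_task_ids_sketch(queue: str, r: Redis) -> set[str]:
    task_ids: set[str] = set()
    # Each queued message is a JSON envelope in a list named after the queue;
    # for the default task protocol the id sits under headers["id"].
    for raw in r.lrange(queue, 0, -1):
        try:
            message = json.loads(raw)
            task_ids.add(message["headers"]["id"])
        except (ValueError, KeyError, TypeError):
            continue  # skip anything we cannot parse
    return task_ids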

View File

@@ -34,7 +34,7 @@ def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
tenant_id: str,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
tenant_id: str,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id

View File

@@ -8,16 +8,21 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
@@ -104,11 +109,10 @@ def revoke_tasks_blocking_deletion(
trail=False,
bind=True,
)
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -120,6 +124,21 @@ def check_for_connector_deletion_task(
return None
try:
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating connector deletion fences"
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
@@ -203,7 +222,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
lock_beat: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
@@ -243,6 +262,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
redis_connector.delete.set_active()
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
@@ -323,7 +343,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
tenant_id: str, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -475,3 +495,171 @@ def monitor_connector_deletion_taskset(
)
redis_connector.delete.reset()
def validate_connector_deletion_fences(
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# validate all existing connector deletion jobs
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
continue
validate_connector_deletion_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
r,
)
lock_beat.reacquire()
return
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.delete.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.delete.payload
except ValidationError:
task_logger.exception(
"validate_connector_deletion_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return
if not payload:
return
# OK, there's actually something for us to validate
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.delete.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_connector_deletion_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.delete.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.delete.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_connector_deletion_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return
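
The validation above leans on an "active" signal with a short TTL to bridge the windows in which a task is neither visible in the queue nor yet observed executing. A minimal sketch of that pattern in plain Redis terms; the key name and TTL are illustrative, and the real helpers live on RedisConnectorDelete.

from redis import Redis

ACTIVE_TTL_SECONDS = 300  # roughly the 5 minute TTL described in the docstring above

def set_active_sketch(r: Redis, fence_key: str) -> None:
    # Renew the signal whenever there is positive evidence the work is alive
    # (fence creation, task seen in the queue, watchdog check, ...).
    r.set(f"{fence_key}:active", 1, ex=ACTIVE_TTL_SECONDS)

def is_active_sketch(r: Redis, fence_key: str) -> bool:
    # If this has expired and no celery tasks were found, the fence is
    # treated as orphaned and reset.
    return bool(r.exists(f"{fence_key}:active"))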

View File

@@ -221,7 +221,7 @@ def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
tenant_id: str,
) -> str | None:
"""Returns a randomized payload id on success.
Returns None if no syncing is required."""
@@ -320,7 +320,7 @@ def try_creating_permissions_sync_task(
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
tenant_id: str,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
@@ -410,7 +410,6 @@ def connector_permission_sync_generator_task(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
tenant_id,
enforce_creation=False,
)
if not created:
@@ -510,7 +509,7 @@ def connector_permission_sync_generator_task(
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str | None,
tenant_id: str,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
@@ -585,7 +584,7 @@ def update_external_document_permissions_task(
def validate_permission_sync_fences(
tenant_id: str | None,
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -632,7 +631,7 @@ def validate_permission_sync_fences(
def validate_permission_sync_fence(
tenant_id: str | None,
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
@@ -842,7 +841,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)

View File

@@ -123,7 +123,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client()
@@ -220,7 +220,7 @@ def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
tenant_id: str,
) -> str | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
@@ -306,7 +306,7 @@ def try_creating_external_group_sync_task(
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
tenant_id: str,
) -> None:
"""
External group sync task for a given connector credential pair
@@ -392,7 +392,6 @@ def connector_external_group_sync_generator_task(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
tenant_id,
enforce_creation=False,
)
if not created:
@@ -494,7 +493,7 @@ def connector_external_group_sync_generator_task(
def validate_external_group_sync_fences(
tenant_id: str | None,
tenant_id: str,
celery_app: Celery,
r: Redis,
r_replica: Redis,
@@ -526,7 +525,7 @@ def validate_external_group_sync_fences(
def validate_external_group_sync_fence(
tenant_id: str | None,
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,

View File

@@ -182,7 +182,7 @@ class SimpleJobResult:
class ConnectorIndexingContext(BaseModel):
tenant_id: str | None
tenant_id: str
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
@@ -210,7 +210,7 @@ class ConnectorIndexingLogBuilder:
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
@@ -358,7 +358,7 @@ def monitor_ccpair_indexing_taskset(
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occasionally does some validation of existing state to clear up error conditions"""
@@ -598,7 +598,7 @@ def connector_indexing_task(
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -890,7 +890,7 @@ def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
tenant_id: str,
) -> None:
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
@@ -1170,7 +1170,7 @@ def connector_indexing_proxy_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
)
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
"""Clean up old checkpoints that are older than 7 days."""
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)

View File

@@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase):
def validate_indexing_fence(
tenant_id: str | None,
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
@@ -311,7 +311,7 @@ def validate_indexing_fence(
def validate_indexing_fences(
tenant_id: str | None,
tenant_id: str,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
@@ -442,7 +442,7 @@ def try_creating_indexing_task(
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.

View File

@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
trail=False,
bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
if not LLM_MODEL_UPDATE_API_URL:
raise ValueError("LLM model update API URL not configured")

View File

@@ -91,7 +91,7 @@ class Metric(BaseModel):
}
task_logger.info(json.dumps(data))
def emit(self, tenant_id: str | None) -> None:
def emit(self, tenant_id: str) -> None:
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None
@@ -656,7 +656,7 @@ def build_job_id(
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
"""Collect and emit metrics about background processes.
This task runs periodically to gather metrics about:
- Queue lengths for different Celery queues
@@ -864,7 +864,7 @@ def cloud_monitor_celery_queues(
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
return monitor_celery_queues_helper(self)

View File

@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up

View File

@@ -114,7 +114,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
@@ -211,7 +211,7 @@ def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str | None,
tenant_id: str,
) -> str | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
@@ -333,7 +333,7 @@ def connector_pruning_generator_task(
cc_pair_id: int,
connector_id: int,
credential_id: int,
tenant_id: str | None,
tenant_id: str,
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -521,7 +521,7 @@ def connector_pruning_generator_task(
def monitor_ccpair_pruning_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -567,7 +567,7 @@ def monitor_ccpair_pruning_taskset(
def validate_pruning_fences(
tenant_id: str | None,
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -615,7 +615,7 @@ def validate_pruning_fences(
def validate_pruning_fence(
tenant_id: str | None,
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],

View File

@@ -32,7 +32,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str | None,
tenant_id: str,
chunk_count: int | None,
) -> int:
return self.index.delete_single(
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str | None,
tenant_id: str,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:

View File

@@ -76,7 +76,7 @@ def document_by_cc_pair_cleanup_task(
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str | None,
tenant_id: str,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
@@ -297,7 +297,8 @@ def cloud_beat_task_generator(
return None
last_lock_time = time.monotonic()
tenant_ids: list[str] | list[None] = []
tenant_ids: list[str] = []
num_processed_tenants = 0
try:
tenant_ids = get_all_tenant_ids()
@@ -325,6 +326,8 @@ def cloud_beat_task_generator(
expires=expires,
ignore_result=True,
)
num_processed_tenants += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -344,6 +347,7 @@ def cloud_beat_task_generator(
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)

View File

@@ -76,7 +76,7 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
@@ -208,7 +208,7 @@ def try_generate_stale_document_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
# the fence is up, do nothing
@@ -284,7 +284,7 @@ def try_generate_document_set_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
lock_beat.reacquire()
@@ -361,7 +361,7 @@ def try_generate_user_group_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> int | None:
lock_beat.reacquire()
@@ -448,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -523,9 +523,7 @@ def monitor_document_set_taskset(
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def vespa_metadata_sync_task(
self: Task, document_id: str, *, tenant_id: str | None
) -> bool:
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED

View File

@@ -16,7 +16,7 @@ from typing import Optional
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.setup import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -55,6 +55,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -67,7 +68,6 @@ def _get_connector_runner(
batch_size: int,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
) -> ConnectorRunner:
"""
@@ -86,7 +86,6 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)
# validate the connector settings
@@ -241,7 +240,7 @@ def _check_failure_threshold(
def _run_indexing(
db_session: Session,
index_attempt_id: int,
tenant_id: str | None,
tenant_id: str,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
@@ -388,7 +387,6 @@ def _run_indexing(
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
# don't use a checkpoint if we're explicitly indexing from
@@ -681,7 +679,7 @@ def _run_indexing(
def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str | None,
tenant_id: str,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
@@ -701,7 +699,7 @@ def run_indexing_entrypoint(
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
if MULTI_TENANT:
tenant_str = f" for tenant {tenant_id}"
connector_name = attempt.connector_credential_pair.connector.name

View File

@@ -6,6 +6,7 @@ from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
#####
@@ -29,6 +30,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
)
#####
# Web Configs

View File

@@ -213,6 +213,12 @@ class AuthType(str, Enum):
CLOUD = "cloud"
class QueryHistoryType(str, Enum):
DISABLED = "disabled"
ANONYMIZED = "anonymized"
NORMAL = "normal"
# Special characters for password validation
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
@@ -342,6 +348,9 @@ class OnyxRedisSignals:
BLOCK_PRUNING = "signal:block_pruning"
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
"signal:block_validate_connector_deletion_fences"
)
class OnyxRedisConstants:

View File

@@ -11,6 +11,8 @@ from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from requests import HTTPError
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.utils.logger import setup_logger
@@ -161,7 +163,7 @@ class OnyxConfluence(Confluence):
)
def _paginate_url(
self, url_suffix: str, limit: int | None = None
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -236,9 +238,41 @@ class OnyxConfluence(Confluence):
raise e
# yield the results individually
yield from next_response.get("results", [])
results = cast(list[dict[str, Any]], next_response.get("results", []))
yield from results
url_suffix = next_response.get("_links", {}).get("next")
old_url_suffix = url_suffix
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
# make sure we don't advance the start by more than the number
# of results we were actually able to retrieve. The Confluence API has a
# weird behavior where, if you pass in a limit that is too large for
# the configured server, it will artificially cap the number of
# results returned BUT will not apply that cap to the start parameter.
# This would cause us to miss results.
if url_suffix and "start" in url_suffix:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start
adjusted_start = previous_start + len(results)
url_suffix = update_param_in_path(
url_suffix, "start", str(adjusted_start)
)
# some APIs don't properly paginate, so we need to manually update the `start` param
if auto_paginate and len(results) > 0:
previous_start = get_start_param_from_url(old_url_suffix)
updated_start = previous_start + len(results)
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
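As a side note on the pagination guard above, here is a rough standalone sketch of the start-adjustment logic, using the get_start_param_from_url and update_param_in_path helpers introduced in this change; the URL values are made up for illustration.

from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path


def adjust_next_url(old_url: str, next_url: str, num_results: int) -> str:
    """Rewind `start` if Confluence advanced it further than the number of
    results it actually returned, so no records get skipped."""
    previous_start = get_start_param_from_url(old_url)
    new_start = get_start_param_from_url(next_url)
    if new_start - previous_start > num_results:
        return update_param_in_path(next_url, "start", str(previous_start + num_results))
    return next_url


# e.g. the server capped the page at 25 results but still bumped start by 50
adjust_next_url(
    "/rest/api/search/user?start=0&limit=50",
    "/rest/api/search/user?start=50&limit=50",
    num_results=25,
)  # -> "/rest/api/search/user?start=25&limit=50"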
def paginated_cql_retrieval(
self,
@@ -298,7 +332,9 @@ class OnyxConfluence(Confluence):
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(url, limit):
# this endpoint doesn't properly paginate, so we need to manually update
# the `start` param; hence the auto_paginate flag
for user_result in self._paginate_url(url, limit, auto_paginate=True):
# Example response:
# {
# 'user': {

View File

@@ -2,7 +2,10 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
@@ -10,13 +13,13 @@ from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.onyx_confluence import (
OnyxConfluence,
)
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
logger = setup_logger()
@@ -24,7 +27,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
@@ -47,7 +50,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
@@ -78,7 +81,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_client: "OnyxConfluence",
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
@@ -191,7 +194,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
def attachment_to_content(
confluence_client: OnyxConfluence,
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
@@ -279,3 +282,32 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)
return parse_qs(parsed_url.query).get(param, [None])[0]
def get_start_param_from_url(url: str) -> int:
"""Get the start parameter from a url"""
start_str = get_single_param_from_url(url, "start")
if start_str is None:
return 0
return int(start_str)
def update_param_in_path(path: str, param: str, value: str) -> str:
"""Update a parameter in a path. Path should look something like:
/api/rest/users?start=0&limit=10
"""
parsed_url = urlparse(path)
query_params = parse_qs(parsed_url.query)
query_params[param] = [value]
return (
path.split("?")[0]
+ "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
)
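For reference, roughly how these helpers behave on a typical paginated path (example values only):

path = "/rest/api/search/user?start=0&limit=50"

get_single_param_from_url(path, "limit")   # -> "50"
get_start_param_from_url(path)             # -> 0
get_start_param_from_url("/rest/api/search/user?limit=50")  # -> 0 (missing start defaults to 0)
update_param_in_path(path, "start", "25")  # -> "/rest/api/search/user?start=25&limit=50"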

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@@ -164,13 +163,9 @@ def instantiate_connector(
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
tenant_id: str | None = None,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
if source in DocumentSourceRequiringTenantContext:
connector_specific_config["tenant_id"] = tenant_id
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)
@@ -184,7 +179,6 @@ def validate_ccpair_for_user(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None,
enforce_creation: bool = True,
) -> bool:
if INTEGRATION_TESTS_MODE:
@@ -216,7 +210,6 @@ def validate_ccpair_for_user(
input_type=connector.input_type,
connector_specific_config=connector.connector_specific_config,
credential=credential,
tenant_id=tenant_id,
)
except ConnectorValidationError as e:
raise e

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -27,8 +27,6 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -165,12 +163,10 @@ class LocalFileConnector(LoadConnector):
def __init__(
self,
file_locations: list[Path | str],
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.tenant_id = tenant_id
self.pdf_pass: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -179,9 +175,8 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
@@ -203,8 +198,6 @@ class LocalFileConnector(LoadConnector):
if documents:
yield documents
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if __name__ == "__main__":
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])

View File

@@ -1,7 +1,9 @@
import io
from datetime import datetime
from datetime import timezone
from tempfile import NamedTemporaryFile
import openpyxl # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
@@ -43,12 +45,15 @@ def _extract_sections_basic(
) -> list[Section]:
mime_type = file["mimeType"]
link = file["webViewLink"]
supported_file_types = set(item.value for item in GDriveMimeType)
if mime_type not in set(item.value for item in GDriveMimeType):
if mime_type not in supported_file_types:
# Unsupported file types can still have a title; finding the file this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
# ---------------------------
# Google Sheets extraction
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
@@ -109,7 +114,53 @@ def _extract_sections_basic(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
# ---------------------------
# Microsoft Excel (.xlsx or .xls) extraction branch
elif mime_type in [
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
]:
try:
response = service.files().get_media(fileId=file["id"]).execute()
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
tmp.write(response)
tmp_path = tmp.name
section_separator = "\n\n"
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
# Works similarly to the xlsx_to_text function used by the file connector,
# but returns Sections instead of a single string
sections = [
Section(
link=link,
text=(
f"Sheet: {sheet.title}\n\n"
+ section_separator.join(
",".join(map(str, row))
for row in sheet.iter_rows(
min_row=1, values_only=True
)
if row
)
),
)
for sheet in workbook.worksheets
]
return sections
except Exception as e:
logger.warning(
f"Error extracting data from Excel file '{file['name']}': {e}"
)
return [
Section(link=link, text="Error extracting data from Excel file")
]
# ---------------------------
# Export for Google Docs, PPT, and fallback for spreadsheets
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
@@ -128,6 +179,8 @@ def _extract_sections_basic(
)
return [Section(link=link, text=text)]
# ---------------------------
# Plain text and Markdown files
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
@@ -141,6 +194,8 @@ def _extract_sections_basic(
.decode("utf-8"),
)
]
# ---------------------------
# Word, PowerPoint, PDF files
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
@@ -170,7 +225,11 @@ def _extract_sections_basic(
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
error_message = f"Unsupported file type: {mime_type}"
logger.error(error_message)
raise ValueError(error_message)
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
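A minimal sketch, outside the connector, of the sheet-to-text shape the new Excel branch produces; the workbook here is built in memory purely for illustration.

import openpyxl

workbook = openpyxl.Workbook()
sheet = workbook.active
sheet.title = "Budget"
sheet.append(["item", "cost"])
sheet.append(["coffee", 4])

text = f"Sheet: {sheet.title}\n\n" + "\n\n".join(
    ",".join(map(str, row))
    for row in sheet.iter_rows(min_row=1, values_only=True)
    if row
)
# text == "Sheet: Budget\n\nitem,cost\n\ncoffee,4"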

View File

@@ -5,6 +5,10 @@ from typing import Any
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
SPREADSHEET_OPEN_FORMAT = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"

View File

@@ -16,7 +16,6 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -73,7 +72,7 @@ def insert_api_key(
# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = get_current_tenant_id()
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
api_key = generate_api_key(tenant_id)
api_key_user_id = uuid.uuid4()
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER

View File

@@ -168,7 +168,7 @@ def get_chat_sessions_by_user(
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
stmt = stmt.order_by(desc(ChatSession.time_updated))
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
@@ -962,6 +962,7 @@ def translate_db_message_to_chat_message_detail(
chat_message.sub_questions
),
refined_answer_improvement=chat_message.refined_answer_improvement,
is_agentic=chat_message.is_agentic,
error=chat_message.error,
)

View File

@@ -0,0 +1,111 @@
from typing import List
from typing import Optional
from typing import Tuple
from uuid import UUID
from sqlalchemy import column
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
def search_chat_sessions(
user_id: UUID | None,
db_session: Session,
query: Optional[str] = None,
page: int = 1,
page_size: int = 10,
include_deleted: bool = False,
include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
"""
Fast full-text search on ChatSession + ChatMessage using tsvectors.
If no query is provided, returns the most recent chat sessions.
Otherwise, searches both chat messages and session descriptions.
Returns a tuple of (sessions, has_more) where has_more indicates if
there are additional results beyond the requested page.
"""
offset_val = (page - 1) * page_size
# If no query, just return the most recent sessions
if not query or not query.strip():
stmt = (
select(ChatSession)
.order_by(desc(ChatSession.time_created))
.offset(offset_val)
.limit(page_size + 1)
)
if user_id is not None:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
sessions = result.scalars().all()
has_more = len(sessions) > page_size
if has_more:
sessions = sessions[:page_size]
return list(sessions), has_more
# Otherwise, proceed with full-text search
query = query.strip()
base_conditions = []
if user_id is not None:
base_conditions.append(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
base_conditions.append(ChatSession.deleted.is_(False))
message_tsv: ColumnClause = column("message_tsv")
description_tsv: ColumnClause = column("description_tsv")
ts_query = func.plainto_tsquery("english", query)
description_session_ids = (
select(ChatSession.id)
.where(*base_conditions)
.where(description_tsv.op("@@")(ts_query))
)
message_session_ids = (
select(ChatMessage.chat_session_id)
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
.where(*base_conditions)
.where(message_tsv.op("@@")(ts_query))
)
combined_ids = description_session_ids.union(message_session_ids).alias(
"combined_ids"
)
final_stmt = (
select(ChatSession)
.join(combined_ids, ChatSession.id == combined_ids.c.id)
.order_by(desc(ChatSession.time_created))
.distinct()
.offset(offset_val)
.limit(page_size + 1)
.options(joinedload(ChatSession.persona))
)
session_objs = db_session.execute(final_stmt).scalars().all()
has_more = len(session_objs) > page_size
if has_more:
session_objs = session_objs[:page_size]
return list(session_objs), has_more
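A hedged usage sketch of the new search helper; the module path and the placeholder user id are assumptions, and get_session_context_manager is the session helper used elsewhere in this change.

from uuid import UUID

from onyx.db.chat_search import search_chat_sessions  # module path assumed; not shown in this hunk
from onyx.db.engine import get_session_context_manager

with get_session_context_manager() as db_session:
    sessions, has_more = search_chat_sessions(
        user_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder
        db_session=db_session,
        query="quarterly report",
        page=1,
        page_size=10,
    )
    for session in sessions:
        print(session.id, session.description)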

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import TypeVarTuple
from fastapi import HTTPException
from sqlalchemy import delete
@@ -8,15 +9,18 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
@@ -31,10 +35,12 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
R = TypeVarTuple("R")
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
) -> Select[tuple[*R]]:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -98,17 +104,52 @@ def get_connector_credential_pairs_for_user(
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
if eager_load_user:
assert (
eager_load_credential
), "eager_load_credential must be True if eager_load_user is True"
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
if eager_load_user:
load_opts = load_opts.joinedload(Credential.user)
stmt = stmt.options(load_opts)
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
return list(db_session.scalars(stmt).unique().all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will no longer be
# available after this function returns, so lazy loading will fail.
def get_connector_credential_pairs_for_user_parallel(
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_connector_credential_pairs_for_user(
db_session,
user,
get_editable,
ids,
eager_load_connector,
eager_load_credential,
eager_load_user,
)
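A rough illustration of the calling pattern these *_parallel helpers enable. The repo's own thread-level parallelism utils are not shown in this diff, so a plain ThreadPoolExecutor stands in here; the user value and cc-pair ids are placeholders.

from concurrent.futures import ThreadPoolExecutor

user = None  # placeholder; with auth disabled, None is treated as an admin

with ThreadPoolExecutor() as pool:
    cc_pairs_future = pool.submit(
        get_connector_credential_pairs_for_user_parallel,
        user,
        eager_load_connector=True,
        eager_load_credential=True,
        eager_load_user=True,
    )
    groups_future = pool.submit(
        get_cc_pair_groups_for_ids_parallel, [1, 2, 3]  # placeholder cc-pair ids
    )

cc_pairs = cc_pairs_future.result()
groups = groups_future.result()
# Relationships accessed from here on (e.g. cc_pair.connector) must have been
# eagerly loaded above, since each helper's session is already closed.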
def get_connector_credential_pairs(
@@ -151,6 +192,16 @@ def get_cc_pair_groups_for_ids(
return list(db_session.scalars(stmt).all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will no longer be
# available after this function returns, so lazy loading will fail.
def get_cc_pair_groups_for_ids_parallel(
cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,

View File

@@ -360,18 +360,13 @@ def backend_update_credential_json(
db_session.commit()
def delete_credential(
def _delete_credential_internal(
credential: Credential,
credential_id: int,
user: User | None,
db_session: Session,
force: bool = False,
) -> None:
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
"""Internal utility function to handle the actual deletion of a credential"""
associated_connectors = (
db_session.query(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.credential_id == credential_id)
@@ -416,6 +411,35 @@ def delete_credential(
db_session.commit()
def delete_credential_for_user(
credential_id: int,
user: User,
db_session: Session,
force: bool = False,
) -> None:
"""Delete a credential that belongs to a specific user"""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
_delete_credential_internal(credential, credential_id, db_session, force)
def delete_credential(
credential_id: int,
db_session: Session,
force: bool = False,
) -> None:
"""Delete a credential regardless of ownership (admin function)"""
credential = fetch_credential_by_id(credential_id, db_session)
if credential is None:
raise ValueError(f"Credential by provided id {credential_id} does not exist")
_delete_credential_internal(credential, credential_id, db_session, force)
def create_initial_public_credential(db_session: Session) -> None:
error_msg = (
"DB is not in a valid initial state."

View File

@@ -24,6 +24,7 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -229,12 +230,12 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
# Prepare a list of (connector_id, credential_id) tuples
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
stmt = (
select(
@@ -260,6 +261,16 @@ def get_document_counts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will no longer be
# available after this function returns, so lazy loading will fail.
def get_document_counts_for_cc_pairs_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
) -> Sequence[tuple[int, int, int]]:
with get_session_context_manager() as db_session:
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
def get_access_info_for_document(
db_session: Session,
document_id: str,

View File

@@ -218,6 +218,7 @@ class SqlEngine:
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:
@@ -257,11 +258,11 @@ class SqlEngine:
cls._engine = None
def get_all_tenant_ids() -> list[str] | list[None]:
def get_all_tenant_ids() -> list[str]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
if not MULTI_TENANT:
return [None]
return [POSTGRES_DEFAULT_SCHEMA]
with get_session_with_shared_schema() as session:
result = session.execute(
@@ -416,7 +417,7 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
@contextmanager
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
"""

View File

@@ -2,6 +2,7 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TypeVarTuple
from sqlalchemy import and_
from sqlalchemy import delete
@@ -9,9 +10,13 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -368,19 +373,33 @@ def get_latest_index_attempts_by_status(
return db_session.execute(stmt).scalars().all()
T = TypeVarTuple("T")
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
return stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
def get_latest_index_attempts(
secondary_index: bool,
db_session: Session,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
if secondary_index:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
else:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
ids_stmt = ids_stmt.where(SearchSettings.status == status)
if only_finished:
ids_stmt = _add_only_finished_clause(ids_stmt)
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
@@ -395,7 +414,53 @@ def get_latest_index_attempts(
.where(IndexAttempt.id == ids_subquery.c.max_id)
)
return db_session.execute(stmt).scalars().all()
if only_finished:
stmt = _add_only_finished_clause(stmt)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return db_session.execute(stmt).scalars().unique().all()
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will no longer be
# available after this function returns, so lazy loading will fail.
def get_latest_index_attempts_parallel(
secondary_index: bool,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,
eager_load_cc_pair,
only_finished,
)
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
)
if only_finished:
stmt = _add_only_finished_clause(stmt)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
def count_index_attempts_for_connector(
@@ -453,37 +518,12 @@ def get_paginated_index_attempts_for_cc_pair_id(
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.execute(stmt).scalars().all())
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
stmt = stmt.options(
contains_eager(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
if only_finished:
stmt = stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
return list(db_session.execute(stmt).scalars().unique().all())
def get_index_attempts_for_cc_pair(

View File

@@ -7,6 +7,7 @@ from typing import Optional
from uuid import uuid4
from pydantic import BaseModel
from sqlalchemy.orm import validates
from typing_extensions import TypedDict # noreorder
from uuid import UUID
@@ -25,6 +26,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Text
@@ -205,6 +207,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
@property
def password_configured(self) -> bool:
"""
@@ -2269,6 +2275,10 @@ class UserTenantMapping(Base):
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
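A minimal standalone sketch of the SQLAlchemy @validates pattern used above, showing that the lowercasing runs on assignment, before anything reaches the database; the Account model is invented for illustration.

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, validates

class Base(DeclarativeBase):
    pass

class Account(Base):
    __tablename__ = "account"
    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str]

    @validates("email")
    def validate_email(self, key: str, value: str) -> str:
        # Runs on every assignment to .email, so mixed-case input is
        # normalized before it is ever persisted.
        return value.lower() if value else value

account = Account(email="Alice@Example.COM")
assert account.email == "alice@example.com"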
# This is a mapping from tenant IDs to anonymous user paths
class TenantAnonymousUserPath(Base):

View File

@@ -100,9 +100,14 @@ def _add_user_filters(
.correlate(Persona)
)
else:
where_clause |= Persona.is_public == True # noqa: E712
where_clause &= Persona.is_visible == True # noqa: E712
# Group the public persona conditions
public_condition = (Persona.is_public == True) & ( # noqa: E712
Persona.is_visible == True # noqa: E712
)
where_clause |= public_condition
where_clause |= Persona__User.user_id == user.id
where_clause |= Persona.user_id == user.id
return stmt.where(where_clause)

View File

@@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
# Vespa's Document API.
def get_document_chunk_ids(
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
tenant_id: str | None,
tenant_id: str,
large_chunks_enabled: bool,
) -> list[UUID]:
doc_chunk_ids = []
@@ -139,7 +139,7 @@ def get_uuid_from_chunk_info(
*,
document_id: str,
chunk_id: int,
tenant_id: str | None,
tenant_id: str,
large_chunk_id: int | None = None,
) -> UUID:
"""NOTE: be VERY carefuly about changing this function. If changed without a migration,
@@ -154,7 +154,7 @@ def get_uuid_from_chunk_info(
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
)
unique_identifier_string = "_".join([doc_str, chunk_index])
if tenant_id and MULTI_TENANT:
if MULTI_TENANT:
unique_identifier_string += "_" + tenant_id
uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
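A hedged sketch of the identifier scheme the NOTE above warns about: the chunk UUID is a uuid5 over a stable string, with the tenant id appended only in multi-tenant mode. The document-id handling and large-chunk ids are elided in this hunk, so they are simplified here.

import uuid

def chunk_uuid(document_id: str, chunk_id: int, tenant_id: str, multi_tenant: bool) -> uuid.UUID:
    # Simplified relative to the real function above.
    unique_identifier_string = "_".join([document_id, str(chunk_id)])
    if multi_tenant:
        # The tenant suffix keeps chunk ids from colliding across tenants;
        # altering this format is exactly the kind of change the NOTE says
        # requires a migration.
        unique_identifier_string += "_" + tenant_id
    return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)

chunk_uuid("doc-123", 0, "tenant_a", multi_tenant=True)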

View File

@@ -43,7 +43,7 @@ class IndexBatchParams:
doc_id_to_previous_chunk_cnt: dict[str, int | None]
doc_id_to_new_chunk_cnt: dict[str, int]
tenant_id: str | None
tenant_id: str
large_chunks_enabled: bool
@@ -222,7 +222,7 @@ class Deletable(abc.ABC):
self,
doc_id: str,
*,
tenant_id: str | None,
tenant_id: str,
chunk_count: int | None,
) -> int:
"""
@@ -249,7 +249,7 @@ class Updatable(abc.ABC):
self,
doc_id: str,
*,
tenant_id: str | None,
tenant_id: str,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:
@@ -270,9 +270,7 @@ class Updatable(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def update(
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
) -> None:
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
"""
Updates some set of chunks. The document and fields to update are specified in the update
requests. Each update request in the list applies its changes to a list of document ids.

View File

@@ -468,9 +468,7 @@ class VespaIndex(DocumentIndex):
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
raise requests.HTTPError(failure_msg) from e
def update(
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
) -> None:
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
# Handle Vespa character limitations
@@ -618,7 +616,7 @@ class VespaIndex(DocumentIndex):
doc_id: str,
*,
chunk_count: int | None,
tenant_id: str | None,
tenant_id: str,
fields: VespaDocumentFields,
) -> int:
"""Note: if the document id does not exist, the update will be a no-op and the
@@ -661,7 +659,7 @@ class VespaIndex(DocumentIndex):
self,
doc_id: str,
*,
tenant_id: str | None,
tenant_id: str,
chunk_count: int | None,
) -> int:
total_chunks_deleted = 0

View File

@@ -158,8 +158,8 @@ def index_doc_batch_with_handler(
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
tenant_id: str,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> IndexingPipelineResult:
try:
index_pipeline_result = index_doc_batch(
@@ -317,8 +317,8 @@ def index_doc_batch(
document_index: DocumentIndex,
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
tenant_id: str,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> IndexingPipelineResult:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
@@ -525,9 +525,9 @@ def build_indexing_pipeline(
embedder: IndexingEmbedder,
document_index: DocumentIndex,
db_session: Session,
tenant_id: str,
chunker: Chunker | None = None,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""

View File

@@ -84,7 +84,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
negative -> ranked lower.
"""
tenant_id: str | None = None
tenant_id: str
access: "DocumentAccess"
document_sets: set[str]
boost: int
@@ -96,7 +96,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
tenant_id: str | None,
tenant_id: str,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
return cls(

View File

@@ -219,7 +219,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
setup_onyx(db_session, None)
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
else:
setup_multitenant_onyx()

View File

@@ -23,7 +23,7 @@ from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_session_by_message_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChannelConfig
from onyx.onyxbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
@@ -410,12 +410,11 @@ def _build_qa_response_blocks(
def _build_continue_in_web_ui_block(
tenant_id: str | None,
message_id: int | None,
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,
@@ -482,7 +481,6 @@ def build_follow_up_resolved_blocks(
def build_slack_response_blocks(
answer: ChatOnyxBotResponse,
tenant_id: str | None,
message_info: SlackMessageInfo,
channel_conf: ChannelConfig | None,
use_citations: bool,
@@ -517,7 +515,6 @@ def build_slack_response_blocks(
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
web_follow_up_block.append(
_build_continue_in_web_ui_block(
tenant_id=tenant_id,
message_id=answer.chat_message_id,
)
)

View File

@@ -11,7 +11,7 @@ from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
@@ -114,7 +114,7 @@ def handle_generate_answer_button(
thread_ts=thread_ts,
)
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
@@ -136,7 +136,6 @@ def handle_generate_answer_button(
slack_channel_config=slack_channel_config,
receiver_ids=None,
client=client.web_client,
tenant_id=client.tenant_id,
channel=channel_id,
logger=logger,
feedback_reminder_id=None,
@@ -151,11 +150,10 @@ def handle_slack_feedback(
user_id_to_post_confirmation: str,
channel_id_to_post_confirmation: str,
thread_ts_to_post_confirmation: str,
tenant_id: str | None,
) -> None:
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
@@ -246,7 +244,7 @@ def handle_followup_button(
tag_ids: list[str] = []
group_ids: list[str] = []
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
channel_name, is_dm = get_channel_name_from_id(
client=client.web_client, channel_id=channel_id
)

View File

@@ -5,7 +5,7 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.onyxbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.users import add_slack_user_if_not_exists
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
@@ -109,7 +109,6 @@ def handle_message(
slack_channel_config: SlackChannelConfig,
client: WebClient,
feedback_reminder_id: str | None,
tenant_id: str | None,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -135,9 +134,7 @@ def handle_message(
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
)
slack_usage_report(action=action, sender_id=sender_id, client=client)
document_set_names: list[str] | None = None
persona = slack_channel_config.persona if slack_channel_config else None
@@ -218,7 +215,7 @@ def handle_message(
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
if message_info.email:
add_slack_user_if_not_exists(db_session, message_info.email)
@@ -244,6 +241,5 @@ def handle_message(
channel=channel,
logger=logger,
feedback_reminder_id=feedback_reminder_id,
tenant_id=tenant_id,
)
return issue_with_regular_answer

View File

@@ -24,7 +24,6 @@ from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
@@ -72,7 +71,6 @@ def handle_regular_answer(
channel: str,
logger: OnyxLoggingAdapter,
feedback_reminder_id: str | None,
tenant_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
@@ -87,7 +85,7 @@ def handle_regular_answer(
user = None
if message_info.is_bot_dm:
if message_info.email:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
@@ -96,7 +94,7 @@ def handle_regular_answer(
# This way slack flow always has a persona
persona = slack_channel_config.persona
if not persona:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
document_set_names = [
document_set.name for document_set in persona.document_sets
@@ -157,7 +155,7 @@ def handle_regular_answer(
def _get_slack_answer(
new_message_request: CreateChatMessageRequest, onyx_user: User | None
) -> ChatOnyxBotResponse:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
packets = stream_chat_message_objects(
new_msg_req=new_message_request,
user=onyx_user,
@@ -197,7 +195,7 @@ def handle_regular_answer(
enable_auto_detect_filters=auto_detect_filters,
)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
answer_request = prepare_chat_message_request(
message_text=user_message.message,
user=user,
@@ -361,7 +359,6 @@ def handle_regular_answer(
return True
all_blocks = build_slack_response_blocks(
tenant_id=tenant_id,
message_info=message_info,
answer=answer,
channel_conf=channel_conf,

View File

@@ -17,10 +17,12 @@ from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.chat.models import ThreadMessage
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import POD_NAME
@@ -35,6 +37,7 @@ from onyx.context.search.retrieval.search_runner import (
download_nltk_data,
)
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import SlackBot
from onyx.db.search_settings import get_current_search_settings
@@ -90,6 +93,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -121,13 +125,13 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str | None] = set()
self.tenant_ids: Set[str] = set()
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
# Store Redis lock objects here so we can release them properly
self.redis_locks: Dict[str | None, Lock] = {}
self.redis_locks: Dict[str, Lock] = {}
self.running = True
self.pod_id = self.get_pod_id()
@@ -191,7 +195,7 @@ class SlackbotHandler:
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot
self, db_session: Session, tenant_id: str, bot: SlackBot
) -> None:
"""
- If the tokens are missing or empty, close the socket client and remove them.
@@ -249,7 +253,12 @@ class SlackbotHandler:
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
all_tenants = [
tenant_id
for tenant_id in get_all_tenant_ids()
if tenant_id not in get_gated_tenants()
]
token: Token[str | None]
@@ -340,7 +349,7 @@ class SlackbotHandler:
redis_client = get_redis_client(tenant_id=tenant_id)
try:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
@@ -378,7 +387,7 @@ class SlackbotHandler:
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _remove_tenant(self, tenant_id: str | None) -> None:
def _remove_tenant(self, tenant_id: str) -> None:
"""
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
@@ -408,7 +417,7 @@ class SlackbotHandler:
)
def start_socket_client(
self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
) -> None:
socket_client: TenantSocketModeClient = _get_socket_client(
slack_bot_tokens, tenant_id, slack_bot_id
@@ -416,6 +425,7 @@ class SlackbotHandler:
try:
bot_info = socket_client.web_client.auth_test()
if bot_info["ok"]:
bot_user_id = bot_info["user_id"]
user_info = socket_client.web_client.users_info(user=bot_user_id)
@@ -426,9 +436,23 @@ class SlackbotHandler:
logger.info(
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
)
except SlackApiError as e:
# Only error out if we get a not_authed error
if "not_authed" in str(e):
self.tenant_ids.add(tenant_id)
logger.error(
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
"Error: {e}"
)
return
# Log other Slack API errors but continue
logger.error(
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
)
except Exception as e:
logger.warning(
f"Could not fetch bot name: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
# Log other exceptions but continue
logger.error(
f"Error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
)
# Append the event handler
@@ -564,7 +588,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
channel_name, _ = get_channel_name_from_id(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
@@ -658,7 +682,6 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
user_id_to_post_confirmation=user_id,
channel_id_to_post_confirmation=channel_id,
thread_ts_to_post_confirmation=thread_ts,
tenant_id=client.tenant_id,
)
query_event_id, _, _ = decompose_action_id(feedback_id)
@@ -774,8 +797,9 @@ def process_message(
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
tenant_id = get_current_tenant_id()
logger.debug(
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
f"Received Slack request of type: '{req.type}' for tenant, {tenant_id}"
)
# Throw out requests that can't or shouldn't be handled
@@ -788,50 +812,39 @@ def process_message(
client=client.web_client, channel_id=channel
)
token: Token[str | None] | None = None
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
try:
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
with get_session_with_current_tenant() as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
follow_up = bool(
slack_channel_config.channel_config
and slack_channel_config.channel_config.get("follow_up_tags")
is not None
)
follow_up = bool(
slack_channel_config.channel_config
and slack_channel_config.channel_config.get("follow_up_tags") is not None
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
failed = handle_message(
message_info=details,
slack_channel_config=slack_channel_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)
failed = handle_message(
message_info=details,
slack_channel_config=slack_channel_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender_id,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
finally:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender_id,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
@@ -890,7 +903,7 @@ def create_process_slack_event() -> (
def _get_socket_client(
slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int
slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int
) -> TenantSocketModeClient:
# For more info on how to set this up, check out the docs:
# https://docs.onyx.app/slack_bot_setup

View File

@@ -4,6 +4,8 @@ import re
import string
import time
import uuid
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from typing import cast
@@ -30,7 +32,7 @@ from onyx.configs.onyxbot_configs import (
)
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
@@ -43,6 +45,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.text_processing import replace_whitespaces_w_space
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -569,9 +572,7 @@ def read_slack_thread(
return thread_messages
def slack_usage_report(
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
) -> None:
def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None:
if DISABLE_TELEMETRY:
return
@@ -583,14 +584,13 @@ def slack_usage_report(
logger.warning("Unable to find sender email")
if sender_email is not None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
onyx_user = get_user_by_email(email=sender_email, db_session=db_session)
optional_telemetry(
record_type=RecordType.USAGE,
data={"action": action},
user_id=str(onyx_user.id) if onyx_user else "Non-Onyx-Or-No-Auth-User",
tenant_id=tenant_id,
)
@@ -663,9 +663,30 @@ def get_feedback_visibility() -> FeedbackVisibility:
class TenantSocketModeClient(SocketModeClient):
def __init__(
self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any
):
def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.tenant_id = tenant_id
self._tenant_id = tenant_id
self.slack_bot_id = slack_bot_id
@contextmanager
def _set_tenant_context(self) -> Generator[None, None, None]:
token = None
try:
if self._tenant_id:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self._tenant_id)
yield
finally:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def enqueue_message(self, message: str) -> None:
with self._set_tenant_context():
super().enqueue_message(message)
def process_message(self) -> None:
with self._set_tenant_context():
super().process_message()
def run_message_listeners(self, message: dict, raw_message: str) -> None:
with self._set_tenant_context():
super().run_message_listeners(message, raw_message)
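The same set/reset pattern in isolation, as a minimal sketch with generic names (not the Slack client):

from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar

CURRENT_TENANT_ID: ContextVar[str | None] = ContextVar("current_tenant_id", default=None)

@contextmanager
def tenant_context(tenant_id: str | None) -> Generator[None, None, None]:
    token = None
    try:
        if tenant_id:
            token = CURRENT_TENANT_ID.set(tenant_id)
        yield
    finally:
        # Reset only if we actually set it, so callers without a tenant are untouched.
        if token:
            CURRENT_TENANT_ID.reset(token)

with tenant_context("tenant_a"):
    assert CURRENT_TENANT_ID.get() == "tenant_a"
assert CURRENT_TENANT_ID.get() is None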

View File

@@ -16,10 +16,10 @@ class RedisConnector:
"""Composes several classes to simplify interacting with a connector and its
associated background tasks / associated redis interactions."""
def __init__(self, tenant_id: str | None, id: int) -> None:
def __init__(self, tenant_id: str, id: int) -> None:
"""id: a connector credential pair id"""
self.tenant_id: str | None = tenant_id
self.tenant_id: str = tenant_id
self.id: int = id
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)

View File

@@ -31,7 +31,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
PREFIX = "connectorsync"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: int) -> None:
def __init__(self, tenant_id: str, id: int) -> None:
super().__init__(tenant_id, str(id))
# documents that should be skipped
@@ -60,7 +60,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> tuple[int, int] | None:
"""We can limit the number of tasks generated here, which is useful to prevent
one tenant from overwhelming the sync queue.

View File

@@ -33,14 +33,22 @@ class RedisConnectorDelete:
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
self.tenant_id: str = tenant_id
self.id = id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -77,6 +85,20 @@ class RedisConnectorDelete:
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
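set_active and active implement a simple TTL heartbeat: refresh a key while work is in flight and let it expire if the worker dies. A minimal sketch with redis-py, assuming a locally reachable Redis; the key name and TTL here are illustrative, the real class derives them from its prefixes:

```python
# Sketch of the TTL-based "active" signal, assuming a local Redis and redis-py.
import redis

r = redis.Redis(host="localhost", port=6379)

ACTIVE_KEY = "connectordeletion_active_42"  # illustrative key
ACTIVE_TTL = 3600                           # seconds; refreshed on every set_active()


def set_active() -> None:
    # the value is irrelevant; only existence plus expiry matter
    r.set(ACTIVE_KEY, 0, ex=ACTIVE_TTL)


def active() -> bool:
    # exists() returns an int count, so wrap it in bool()
    return bool(r.exists(ACTIVE_KEY))


set_active()
assert active()
# If the worker dies and never refreshes the key, it expires after ACTIVE_TTL
# and cleanup code can safely treat the flow as abandoned.
```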
def _generate_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
@@ -141,6 +163,7 @@ class RedisConnectorDelete:
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.active_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@@ -153,6 +176,9 @@ class RedisConnectorDelete:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorDelete.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
r.delete(key)

View File

@@ -52,8 +52,8 @@ class RedisConnectorPermissionSync:
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
self.tenant_id: str = tenant_id
self.id = id
self.redis = redis

View File

@@ -44,8 +44,8 @@ class RedisConnectorExternalGroupSync:
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
self.tenant_id: str = tenant_id
self.id = id
self.redis = redis

View File

@@ -52,12 +52,12 @@ class RedisConnectorIndex:
def __init__(
self,
tenant_id: str | None,
tenant_id: str,
id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str | None = tenant_id
self.tenant_id: str = tenant_id
self.id = id
self.search_settings_id = search_settings_id
self.redis = redis
@@ -93,10 +93,7 @@ class RedisConnectorIndex:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorIndexPayload | None:
@@ -106,9 +103,7 @@ class RedisConnectorIndex:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
return payload
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
def set_fence(
self,
@@ -123,10 +118,7 @@ class RedisConnectorIndex:
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def terminating(self, celery_task_id: str) -> bool:
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True
return False
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
@@ -146,10 +138,7 @@ class RedisConnectorIndex:
def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
if self.redis.exists(self.watchdog_key):
return True
return False
return bool(self.redis.exists(self.watchdog_key))
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -160,10 +149,7 @@ class RedisConnectorIndex:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
def set_connector_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -180,10 +166,7 @@ class RedisConnectorIndex:
return False
def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True
return False
return bool(self.redis.exists(self.generator_lock_key))
def set_generator_complete(self, payload: int | None) -> None:
if not payload:

View File

@@ -52,8 +52,8 @@ class RedisConnectorPrune:
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
self.tenant_id: str = tenant_id
self.id = id
self.redis = redis

View File

@@ -13,8 +13,8 @@ class RedisConnectorStop:
TIMEOUT_PREFIX = f"{PREFIX}_timeout"
TIMEOUT_TTL = 300
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
self.tenant_id: str = tenant_id
self.id: int = id
self.redis = redis

View File

@@ -23,7 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: int) -> None:
def __init__(self, tenant_id: str, id: int) -> None:
super().__init__(tenant_id, str(id))
@property
@@ -58,7 +58,7 @@ class RedisDocumentSet(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> tuple[int, int] | None:
"""Max tasks is ignored for now until we can build the logic to mark the
document set up to date over multiple batches.

View File

@@ -14,8 +14,8 @@ class RedisObjectHelper(ABC):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: str):
self._tenant_id: str | None = tenant_id
def __init__(self, tenant_id: str, id: str):
self._tenant_id: str = tenant_id
self._id: str = id
self.redis = get_redis_client(tenant_id=tenant_id)
@@ -87,7 +87,7 @@ class RedisObjectHelper(ABC):
db_session: Session,
redis_client: Redis,
lock: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> tuple[int, int] | None:
"""First element should be the number of actual tasks generated, second should
be the number of docs that were candidates to be synced for the cc pair.

View File

@@ -24,7 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, tenant_id: str | None, id: int) -> None:
def __init__(self, tenant_id: str, id: int) -> None:
super().__init__(tenant_id, str(id))
@property
@@ -59,7 +59,7 @@ class RedisUserGroup(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: RedisLock,
tenant_id: str | None,
tenant_id: str,
) -> tuple[int, int] | None:
"""Max tasks is ignored for now until we can build the logic to mark the
user group up to date over multiple batches.

View File

@@ -37,13 +37,15 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import ConnectorBase
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
def _create_indexable_chunks(
preprocessed_docs: list[dict],
tenant_id: str | None,
tenant_id: str,
) -> tuple[list[Document], list[DocMetadataAwareIndexChunk]]:
ids_to_documents = {}
chunks = []
@@ -86,7 +88,7 @@ def _create_indexable_chunks(
mini_chunk_embeddings=[],
),
title_embedding=preprocessed_doc["title_embedding"],
tenant_id=tenant_id,
tenant_id=tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA,
access=default_public_access,
document_sets=set(),
boost=DEFAULT_BOOST,
@@ -111,7 +113,7 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]:
def seed_initial_documents(
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
db_session: Session, tenant_id: str, cohere_enabled: bool = False
) -> None:
"""
Seed initial documents so users don't have an empty index to start

View File

@@ -123,15 +123,15 @@ def get_cc_pair_full_info(
)
is_editable_for_current_user = editable_cc_pair is not None
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
document_count_info_list = list(
get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=[cc_pair_identifier],
cc_pairs=[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
],
)
)
documents_indexed = (
@@ -620,7 +620,7 @@ def associate_credential_to_connector(
)
try:
validate_ccpair_for_user(connector_id, credential_id, db_session, tenant_id)
validate_ccpair_for_user(connector_id, credential_id, db_session)
response = add_credential_to_connector(
db_session=db_session,
@@ -646,7 +646,6 @@ def associate_credential_to_connector(
)
return response
except ValidationError as e:
# If validation fails, delete the connector and commit the changes
# Ensures we don't leave invalid connectors in the database
@@ -660,10 +659,14 @@ def associate_credential_to_connector(
)
except IntegrityError as e:
logger.error(f"IntegrityError: {e}")
delete_connector(db_session, connector_id)
db_session.commit()
raise HTTPException(status_code=400, detail="Name must be unique")
except Exception as e:
logger.exception(f"Unexpected error: {e}")
raise HTTPException(status_code=500, detail="Unexpected error")

View File

@@ -72,25 +72,31 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector import update_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids_parallel
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_for_user_parallel,
)
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import delete_service_account_credentials
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
from onyx.db.document import get_document_counts_for_cc_pairs
from onyx.db.document import get_document_counts_for_cc_pairs_parallel
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import IndexingMode
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_latest_index_attempts
from onyx.db.index_attempt import get_latest_index_attempts_by_status
from onyx.db.index_attempt import get_latest_index_attempts_parallel
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.file_processing.extract_file_text import convert_docx_to_txt
@@ -119,8 +125,8 @@ from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -578,6 +584,8 @@ def get_connector_status(
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
eager_load_connector=True,
eager_load_credential=True,
)
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
@@ -632,23 +640,35 @@ def get_connector_indexing_status(
# Additional checks are done to make sure the connector and credential still exist.
# TODO: make this one query ... possibly eager load or wrap in a read transaction
# to avoid the complexity of trying to error check throughout the function
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
)
cc_pair_identifiers = [
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
for cc_pair in cc_pairs
]
latest_index_attempts = get_latest_index_attempts(
secondary_index=secondary_index,
db_session=db_session,
# see https://stackoverflow.com/questions/75758327/
# sqlalchemy-method-connection-for-bind-is-already-in-progress
# for why we can't pass in the current db_session to these functions
(
cc_pairs,
latest_index_attempts,
latest_finished_index_attempts,
) = run_functions_tuples_in_parallel(
[
(
# Gets the connector/credential pairs for the user
get_connector_credential_pairs_for_user_parallel,
(user, get_editable, None, True, True, True),
),
(
# Gets the most recent index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, False),
),
(
# Gets the most recent FINISHED index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, True),
),
]
)
cc_pairs = cast(list[ConnectorCredentialPair], cc_pairs)
latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts)
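run_functions_tuples_in_parallel fans these independent reads out to worker threads and returns their results in call order; each *_parallel variant presumably opens its own session internally, which is why no db_session is passed (see the StackOverflow note above). A rough stand-in using concurrent.futures, with dummy fetch functions in place of the real queries:

```python
# Hedged sketch of the fan-out pattern; run_functions_tuples_in_parallel is the
# project's helper, this stand-in only shows the shape of the call.
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Sequence


def run_in_parallel(calls: Sequence[tuple[Callable[..., Any], tuple]]) -> list[Any]:
    with ThreadPoolExecutor(max_workers=max(len(calls), 1)) as pool:
        futures = [pool.submit(fn, *args) for fn, args in calls]
        # results come back in submission order, matching the tuple unpacking above
        return [f.result() for f in futures]


# Dummy fetchers; the real ones would each open their own DB session.
def fetch_cc_pairs(editable: bool) -> list[str]:
    return ["cc_pair_1"] if editable else ["cc_pair_1", "cc_pair_2"]


def fetch_latest_attempts(secondary: bool) -> list[str]:
    return [] if secondary else ["attempt_9"]


cc_pairs, attempts = run_in_parallel(
    [(fetch_cc_pairs, (False,)), (fetch_latest_attempts, (False,))]
)
print(cc_pairs, attempts)
```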
cc_pair_to_latest_index_attempt = {
(
@@ -658,31 +678,60 @@ def get_connector_indexing_status(
for index_attempt in latest_index_attempts
}
document_count_info = get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=cc_pair_identifiers,
cc_pair_to_latest_finished_index_attempt = {
(
index_attempt.connector_credential_pair.connector_id,
index_attempt.connector_credential_pair.credential_id,
): index_attempt
for index_attempt in latest_finished_index_attempts
}
document_count_info, group_cc_pair_relationships = run_functions_tuples_in_parallel(
[
(
get_document_counts_for_cc_pairs_parallel,
(
[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
for cc_pair in cc_pairs
],
),
),
(
get_cc_pair_groups_for_ids_parallel,
([cc_pair.id for cc_pair in cc_pairs],),
),
]
)
document_count_info = cast(list[tuple[int, int, int]], document_count_info)
group_cc_pair_relationships = cast(
list[UserGroup__ConnectorCredentialPair], group_cc_pair_relationships
)
cc_pair_to_document_cnt = {
(connector_id, credential_id): cnt
for connector_id, credential_id, cnt in document_count_info
}
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
db_session=db_session,
cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs],
)
group_cc_pair_relationships_dict: dict[int, list[int]] = {}
for relationship in group_cc_pair_relationships:
group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append(
relationship.user_group_id
)
search_settings: SearchSettings | None = None
if not secondary_index:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_secondary_search_settings(db_session)
connector_to_cc_pair_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_cc_pair_ids.setdefault(cc_pair.connector_id, []).append(cc_pair.id)
get_search_settings = (
get_secondary_search_settings
if secondary_index
else get_current_search_settings
)
search_settings = get_search_settings(db_session)
for cc_pair in cc_pairs:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
@@ -705,11 +754,8 @@ def get_connector_indexing_status(
(connector.id, credential.id)
)
latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id(
db_session=db_session,
connector_credential_pair_id=cc_pair.id,
secondary_index=secondary_index,
only_finished=True,
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
(connector.id, credential.id)
)
indexing_statuses.append(
@@ -718,7 +764,9 @@ def get_connector_indexing_status(
name=cc_pair.name,
in_progress=in_progress,
cc_pair_status=cc_pair.status,
connector=ConnectorSnapshot.from_connector_db_model(connector),
connector=ConnectorSnapshot.from_connector_db_model(
connector, connector_to_cc_pair_ids.get(connector.id, [])
),
credential=CredentialSnapshot.from_credential_db_model(credential),
access_type=cc_pair.access_type,
owner=credential.user.email if credential.user else "",
@@ -854,7 +902,6 @@ def create_connector_with_mock_credential(
connector_id=connector_id,
credential_id=credential_id,
db_session=db_session,
tenant_id=tenant_id,
)
response = add_credential_to_connector(
db_session=db_session,

View File

@@ -13,12 +13,12 @@ from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
from onyx.db.credentials import delete_credential
from onyx.db.credentials import delete_credential_for_user
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import fetch_credentials_by_source_for_user
from onyx.db.credentials import fetch_credentials_for_user
from onyx.db.credentials import swap_credentials_connector
from onyx.db.credentials import update_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import DocumentSource
from onyx.db.models import User
@@ -89,7 +89,7 @@ def delete_credential_by_id_admin(
db_session: Session = Depends(get_session),
) -> StatusResponse:
"""Same as the user endpoint, but can delete any credential (not just the user's own)"""
delete_credential(db_session=db_session, credential_id=credential_id, user=None)
delete_credential(db_session=db_session, credential_id=credential_id)
return StatusResponse(
success=True, message="Credential deleted successfully", data=credential_id
)
@@ -100,13 +100,11 @@ def swap_credentials_for_connector(
credential_swap_req: CredentialSwapRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse:
validate_ccpair_for_user(
credential_swap_req.connector_id,
credential_swap_req.new_credential_id,
db_session,
tenant_id,
)
connector_credential_pair = swap_credentials_connector(
@@ -245,7 +243,7 @@ def delete_credential_by_id(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
delete_credential(
delete_credential_for_user(
credential_id,
user,
db_session,
@@ -262,7 +260,7 @@ def force_delete_credential_by_id(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
delete_credential(credential_id, user, db_session, True)
delete_credential_for_user(credential_id, user, db_session, True)
return StatusResponse(
success=True, message="Credential deleted successfully", data=credential_id

View File

@@ -83,7 +83,9 @@ class ConnectorSnapshot(ConnectorBase):
source: DocumentSource
@classmethod
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
def from_connector_db_model(
cls, connector: Connector, credential_ids: list[int] | None = None
) -> "ConnectorSnapshot":
return ConnectorSnapshot(
id=connector.id,
name=connector.name,
@@ -92,9 +94,10 @@ class ConnectorSnapshot(ConnectorBase):
connector_specific_config=connector.connector_specific_config,
refresh_freq=connector.refresh_freq,
prune_freq=connector.prune_freq,
credential_ids=[
association.credential.id for association in connector.credentials
],
credential_ids=(
credential_ids
or [association.credential.id for association in connector.credentials]
),
indexing_start=connector.indexing_start,
time_created=connector.time_created,
time_updated=connector.time_updated,
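The optional credential_ids parameter lets the caller supply IDs it has already loaded, so the snapshot does not have to touch connector.credentials (a lazy relationship) once per row. A hedged sketch of that precompute-then-pass pattern, with made-up dataclasses standing in for the ORM models:

```python
# Sketch: build one mapping up front instead of triggering a lazy load per row.
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Pair:
    connector_id: int
    credential_id: int


pairs = [Pair(1, 10), Pair(1, 11), Pair(2, 12)]

# one pass over the already-loaded pairs...
connector_to_credential_ids: dict[int, list[int]] = defaultdict(list)
for pair in pairs:
    connector_to_credential_ids[pair.connector_id].append(pair.credential_id)

# ...then each snapshot uses the precomputed list (empty list as fallback)
# instead of issuing a query per connector.
for connector_id in (1, 2, 3):
    print(connector_id, connector_to_credential_ids.get(connector_id, []))
```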

View File

@@ -49,6 +49,7 @@ def get_folders(
name=chat_session.description,
persona_id=chat_session.persona_id,
time_created=chat_session.time_created.isoformat(),
time_updated=chat_session.time_updated.isoformat(),
shared_status=chat_session.shared_status,
folder_id=folder.id,
)

View File

@@ -147,9 +147,11 @@ def list_threads(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
current_temperature_override=chat.temperature_override,
)
for chat in chat_sessions
]

View File

@@ -1,15 +1,18 @@
import asyncio
import datetime
import io
import json
import os
import uuid
from collections.abc import Callable
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
@@ -44,6 +47,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
@@ -65,10 +69,13 @@ from onyx.secondary_llm_flows.chat_session_naming import (
from onyx.server.query_and_chat.models import ChatFeedbackRequest
from onyx.server.query_and_chat.models import ChatMessageIdentifier
from onyx.server.query_and_chat.models import ChatRenameRequest
from onyx.server.query_and_chat.models import ChatSearchResponse
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import ChatSessionDetailResponse
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionGroup
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.query_and_chat.models import ChatSessionSummary
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import CreateChatSessionID
@@ -112,6 +119,7 @@ def get_user_chat_sessions(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
@@ -794,3 +802,84 @@ def fetch_chat_file(
file_io = file_store.read_file(file_id, mode="b")
return StreamingResponse(file_io, media_type=media_type)
@router.get("/search")
async def search_chats(
query: str | None = Query(None),
page: int = Query(1),
page_size: int = Query(10),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatSearchResponse:
"""
Search for chat sessions based on the provided query.
If no query is provided, returns recent chat sessions.
"""
# Use the enhanced database function for chat search
chat_sessions, has_more = search_chat_sessions(
user_id=user.id if user else None,
db_session=db_session,
query=query,
page=page,
page_size=page_size,
include_deleted=False,
include_onyxbot_flows=False,
)
# Group chat sessions by time period
today = datetime.datetime.now().date()
yesterday = today - timedelta(days=1)
this_week = today - timedelta(days=7)
this_month = today - timedelta(days=30)
today_chats: list[ChatSessionSummary] = []
yesterday_chats: list[ChatSessionSummary] = []
this_week_chats: list[ChatSessionSummary] = []
this_month_chats: list[ChatSessionSummary] = []
older_chats: list[ChatSessionSummary] = []
for session in chat_sessions:
session_date = session.time_created.date()
chat_summary = ChatSessionSummary(
id=session.id,
name=session.description,
persona_id=session.persona_id,
time_created=session.time_created,
shared_status=session.shared_status,
folder_id=session.folder_id,
current_alternate_model=session.current_alternate_model,
current_temperature_override=session.temperature_override,
)
if session_date == today:
today_chats.append(chat_summary)
elif session_date == yesterday:
yesterday_chats.append(chat_summary)
elif session_date > this_week:
this_week_chats.append(chat_summary)
elif session_date > this_month:
this_month_chats.append(chat_summary)
else:
older_chats.append(chat_summary)
# Create groups
groups = []
if today_chats:
groups.append(ChatSessionGroup(title="Today", chats=today_chats))
if yesterday_chats:
groups.append(ChatSessionGroup(title="Yesterday", chats=yesterday_chats))
if this_week_chats:
groups.append(ChatSessionGroup(title="This Week", chats=this_week_chats))
if this_month_chats:
groups.append(ChatSessionGroup(title="This Month", chats=this_month_chats))
if older_chats:
groups.append(ChatSessionGroup(title="Older", chats=older_chats))
return ChatSearchResponse(
groups=groups,
has_more=has_more,
next_page=page + 1 if has_more else None,
)
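A hedged usage sketch of the pagination contract this endpoint exposes (groups plus has_more / next_page); the base URL and mount path below are assumptions for illustration, not taken from the router configuration:

```python
# Client-side sketch of paging through the chat search endpoint.
import requests

BASE = "http://localhost:8080/chat/search"  # assumed mount path


def iter_chat_groups(query: str | None = None, page_size: int = 10):
    page = 1
    while True:
        resp = requests.get(
            BASE,
            params={"query": query, "page": page, "page_size": page_size},
        )
        resp.raise_for_status()
        body = resp.json()
        for group in body["groups"]:  # "Today", "Yesterday", "This Week", ...
            yield group["title"], group["chats"]
        if not body["has_more"]:
            break
        page = body["next_page"]      # server supplies the next page number


for title, chats in iter_chat_groups(query="pricing"):
    print(title, len(chats))
```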

View File

@@ -24,6 +24,7 @@ from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
pass
@@ -180,6 +181,7 @@ class ChatSessionDetails(BaseModel):
name: str | None
persona_id: int | None = None
time_created: str
time_updated: str
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
current_alternate_model: str | None = None
@@ -240,6 +242,7 @@ class ChatMessageDetail(BaseModel):
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None
refined_answer_improvement: bool | None = None
is_agentic: bool | None = None
error: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
@@ -282,3 +285,35 @@ class AdminSearchRequest(BaseModel):
class AdminSearchResponse(BaseModel):
documents: list[SearchDoc]
class ChatSessionSummary(BaseModel):
id: UUID
name: str | None = None
persona_id: int | None = None
time_created: datetime
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
current_alternate_model: str | None = None
current_temperature_override: float | None = None
class ChatSessionGroup(BaseModel):
title: str
chats: list[ChatSessionSummary]
class ChatSearchResponse(BaseModel):
groups: list[ChatSessionGroup]
has_more: bool
next_page: int | None = None
class ChatSearchRequest(BaseModel):
query: str | None = None
page: int = 1
page_size: int = 10
class CreateChatResponse(BaseModel):
chat_session_id: str

View File

@@ -159,6 +159,7 @@ def get_user_search_sessions(
name=sessions_with_documents_dict[search.id],
persona_id=search.persona_id,
time_created=search.time_created.isoformat(),
time_updated=search.time_updated.isoformat(),
shared_status=search.shared_status,
folder_id=search.folder_id,
current_alternate_model=search.current_alternate_model,

View File

@@ -4,7 +4,9 @@ from enum import Enum
from pydantic import BaseModel
from onyx.configs.constants import NotificationType
from onyx.configs.constants import QueryHistoryType
from onyx.db.models import Notification as NotificationDBModel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
class PageType(str, Enum):
@@ -49,9 +51,10 @@ class Settings(BaseModel):
temperature_override_enabled: bool | None = False
auto_scroll: bool | None = False
query_history_type: QueryHistoryType | None = None
class UserSettings(Settings):
notifications: list[Notification]
needs_reindexing: bool
tenant_id: str | None = None
tenant_id: str = POSTGRES_DEFAULT_SCHEMA

View File

@@ -1,3 +1,4 @@
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
@@ -45,6 +46,7 @@ def load_settings() -> Settings:
anonymous_user_enabled = False
settings.anonymous_user_enabled = anonymous_user_enabled
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
return settings

View File

@@ -65,7 +65,7 @@ logger = setup_logger()
def setup_onyx(
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
db_session: Session, tenant_id: str, cohere_enabled: bool = False
) -> None:
"""
Setup Onyx for a particular tenant. In the Single Tenant case, it will set it up for the default schema

View File

@@ -260,7 +260,7 @@ def get_documents_for_tenant_connector(
def search_for_document(
index_name: str,
document_id: str | None = None,
tenant_id: str | None = None,
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
max_hits: int | None = 10,
) -> List[Dict[str, Any]]:
yql_query = f"select * from sources {index_name}"
@@ -507,9 +507,9 @@ def get_number_of_chunks_we_think_exist(
class VespaDebugging:
# Class for managing Vespa debugging actions.
def __init__(self, tenant_id: str | None = None):
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id
self.tenant_id = tenant_id
self.index_name = get_index_name(self.tenant_id)
def sample_document_counts(self) -> None:
@@ -603,7 +603,7 @@ class VespaDebugging:
delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)
def search_for_document(
self, document_id: str | None = None, tenant_id: str | None = None
self, document_id: str | None = None, tenant_id: str = POSTGRES_DEFAULT_SCHEMA
) -> List[Dict[str, Any]]:
return search_for_document(self.index_name, document_id, tenant_id)

View File

@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.search_settings import get_active_search_settings
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# Modify sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -74,7 +75,7 @@ def _unsafe_deletion(
for document in documents:
document_index.delete_single(
doc_id=document.id,
tenant_id=None,
tenant_id=POSTGRES_DEFAULT_SCHEMA,
chunk_count=document.chunk_count,
)

View File

@@ -6,6 +6,7 @@ from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.document_index.document_index_utils import get_multipass_config
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
# makes it so `PYTHONPATH=.` is not required when running this script
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -96,7 +97,9 @@ def main() -> None:
try:
print(f"Deleting document {doc_id} in Vespa")
chunks_deleted = vespa_index.delete_single(
doc_id, tenant_id=None, chunk_count=document.chunk_count
doc_id,
tenant_id=POSTGRES_DEFAULT_SCHEMA,
chunk_count=document.chunk_count,
)
if chunks_deleted > 0:
print(

View File

@@ -17,6 +17,8 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
def get_current_tenant_id() -> str:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id is None and MULTI_TENANT:
if tenant_id is None:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
raise RuntimeError("Tenant ID is not set. This should never happen.")
return tenant_id
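A small, runnable sketch of the fallback behavior above; POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT are given stand-in values here since the real ones come from shared_configs:

```python
# In single-tenant mode an unset contextvar resolves to the default schema;
# in multi-tenant mode it is treated as a programming error.
import contextvars

POSTGRES_DEFAULT_SCHEMA = "public"  # stand-in value
MULTI_TENANT = False                # flip to True to see the strict behavior

TENANT_ID: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "tenant_id", default=None
)


def get_current_tenant_id() -> str:
    tenant_id = TENANT_ID.get()
    if tenant_id is None:
        if not MULTI_TENANT:
            return POSTGRES_DEFAULT_SCHEMA
        raise RuntimeError("Tenant ID is not set.")
    return tenant_id


print(get_current_tenant_id())  # "public" while MULTI_TENANT is False
TENANT_ID.set("tenant_abc")
print(get_current_tenant_id())  # "tenant_abc"
```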

View File

@@ -87,7 +87,7 @@ def test_confluence_connector_basic(
assert len(txt_doc.sections) == 1
assert txt_doc.sections[0].text == "small"
assert txt_doc.primary_owners
assert txt_doc.primary_owners[0].email == "chris@danswer.ai"
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
assert (
txt_doc.sections[0].link
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"

View File

@@ -68,6 +68,28 @@ const nextConfig = {
},
];
},
async rewrites() {
return [
{
source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/docs/:path*`,
},
{
source: "/api/docs", // if you also need the exact /api/docs
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/docs`,
},
{
source: "/openapi.json",
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/openapi.json`,
},
];
},
};
// Sentry configuration for error monitoring:

web/package-lock.json (generated, 89 changed lines)
View File

@@ -70,6 +70,8 @@
"recharts": "^2.13.1",
"rehype-katex": "^7.0.1",
"rehype-prism-plus": "^2.0.0",
"rehype-sanitize": "^6.0.0",
"rehype-stringify": "^10.0.1",
"remark-gfm": "^4.0.0",
"remark-math": "^6.0.0",
"semver": "^7.5.4",
@@ -11741,6 +11743,54 @@
"resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz",
"integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA=="
},
"node_modules/hast-util-sanitize": {
"version": "5.0.2",
"resolved": "https://registry.npmjs.org/hast-util-sanitize/-/hast-util-sanitize-5.0.2.tgz",
"integrity": "sha512-3yTWghByc50aGS7JlGhk61SPenfE/p1oaFeNwkOOyrscaOkMGrcW9+Cy/QAIOBpZxP1yqDIzFMR0+Np0i0+usg==",
"license": "MIT",
"dependencies": {
"@types/hast": "^3.0.0",
"@ungap/structured-clone": "^1.0.0",
"unist-util-position": "^5.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/unified"
}
},
"node_modules/hast-util-to-html": {
"version": "9.0.5",
"resolved": "https://registry.npmjs.org/hast-util-to-html/-/hast-util-to-html-9.0.5.tgz",
"integrity": "sha512-OguPdidb+fbHQSU4Q4ZiLKnzWo8Wwsf5bZfbvu7//a9oTYoqD/fWpe96NuHkoS9h0ccGOTe0C4NGXdtS0iObOw==",
"license": "MIT",
"dependencies": {
"@types/hast": "^3.0.0",
"@types/unist": "^3.0.0",
"ccount": "^2.0.0",
"comma-separated-tokens": "^2.0.0",
"hast-util-whitespace": "^3.0.0",
"html-void-elements": "^3.0.0",
"mdast-util-to-hast": "^13.0.0",
"property-information": "^7.0.0",
"space-separated-tokens": "^2.0.0",
"stringify-entities": "^4.0.0",
"zwitch": "^2.0.4"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/unified"
}
},
"node_modules/hast-util-to-html/node_modules/property-information": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/property-information/-/property-information-7.0.0.tgz",
"integrity": "sha512-7D/qOz/+Y4X/rzSB6jKxKUsQnphO046ei8qxG59mtM3RG3DHgTK81HrxrmoDVINJb8NKT5ZsRbwHvQ6B68Iyhg==",
"license": "MIT",
"funding": {
"type": "github",
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/hast-util-to-jsx-runtime": {
"version": "2.3.0",
"resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.0.tgz",
@@ -11919,6 +11969,16 @@
"url": "https://opencollective.com/unified"
}
},
"node_modules/html-void-elements": {
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz",
"integrity": "sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==",
"license": "MIT",
"funding": {
"type": "github",
"url": "https://github.com/sponsors/wooorm"
}
},
"node_modules/html-webpack-plugin": {
"version": "5.6.3",
"resolved": "https://registry.npmjs.org/html-webpack-plugin/-/html-webpack-plugin-5.6.3.tgz",
@@ -19125,6 +19185,35 @@
"unist-util-visit": "^5.0.0"
}
},
"node_modules/rehype-sanitize": {
"version": "6.0.0",
"resolved": "https://registry.npmjs.org/rehype-sanitize/-/rehype-sanitize-6.0.0.tgz",
"integrity": "sha512-CsnhKNsyI8Tub6L4sm5ZFsme4puGfc6pYylvXo1AeqaGbjOYyzNv3qZPwvs0oMJ39eryyeOdmxwUIo94IpEhqg==",
"license": "MIT",
"dependencies": {
"@types/hast": "^3.0.0",
"hast-util-sanitize": "^5.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/unified"
}
},
"node_modules/rehype-stringify": {
"version": "10.0.1",
"resolved": "https://registry.npmjs.org/rehype-stringify/-/rehype-stringify-10.0.1.tgz",
"integrity": "sha512-k9ecfXHmIPuFVI61B9DeLPN0qFHfawM6RsuX48hoqlaKSF61RskNjSm1lI8PhBEM0MRdLxVVm4WmTqJQccH9mA==",
"license": "MIT",
"dependencies": {
"@types/hast": "^3.0.0",
"hast-util-to-html": "^9.0.0",
"unified": "^11.0.0"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/unified"
}
},
"node_modules/relateurl": {
"version": "0.2.7",
"resolved": "https://registry.npmjs.org/relateurl/-/relateurl-0.2.7.tgz",

View File

@@ -73,6 +73,8 @@
"recharts": "^2.13.1",
"rehype-katex": "^7.0.1",
"rehype-prism-plus": "^2.0.0",
"rehype-sanitize": "^6.0.0",
"rehype-stringify": "^10.0.1",
"remark-gfm": "^4.0.0",
"remark-math": "^6.0.0",
"semver": "^7.5.4",

View File

@@ -1,6 +1,6 @@
import { Button } from "@/components/Button";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { useState } from "react";
import React, { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
import * as Yup from "yup";
import { useRouter } from "next/navigation";
@@ -17,13 +17,18 @@ import {
GoogleDriveCredentialJson,
GoogleDriveServiceAccountCredentialJson,
} from "@/lib/connectors/credentials";
import { refreshAllGoogleData } from "@/lib/googleConnector";
import { ValidSources } from "@/lib/types";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account";
export const DriveJsonUpload = ({
setPopup,
onSuccess,
}: {
setPopup: (popupSpec: PopupSpec | null) => void;
onSuccess?: () => void;
}) => {
const { mutate } = useSWRConfig();
const [credentialJsonStr, setCredentialJsonStr] = useState<
@@ -62,7 +67,6 @@ export const DriveJsonUpload = ({
<Button
disabled={!credentialJsonStr}
onClick={async () => {
// check if the JSON is a app credential or a service account credential
let credentialFileType: GoogleDriveCredentialJsonTypes;
try {
const appCredentialJson = JSON.parse(credentialJsonStr!);
@@ -99,6 +103,10 @@ export const DriveJsonUpload = ({
message: "Successfully uploaded app credentials",
type: "success",
});
mutate("/api/manage/admin/connector/google-drive/app-credential");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
@@ -106,7 +114,6 @@ export const DriveJsonUpload = ({
type: "error",
});
}
mutate("/api/manage/admin/connector/google-drive/app-credential");
}
if (credentialFileType === "service_account") {
@@ -122,19 +129,22 @@ export const DriveJsonUpload = ({
);
if (response.ok) {
setPopup({
message: "Successfully uploaded app credentials",
message: "Successfully uploaded service account key",
type: "success",
});
mutate(
"/api/manage/admin/connector/google-drive/service-account-key"
);
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload app credentials - ${errorMsg}`,
message: `Failed to upload service account key - ${errorMsg}`,
type: "error",
});
}
mutate(
"/api/manage/admin/connector/google-drive/service-account-key"
);
}
}}
>
@@ -149,6 +159,7 @@ interface DriveJsonUploadSectionProps {
appCredentialData?: { client_id: string };
serviceAccountCredentialData?: { service_account_email: string };
isAdmin: boolean;
onSuccess?: () => void;
}
export const DriveJsonUploadSection = ({
@@ -156,17 +167,36 @@ export const DriveJsonUploadSection = ({
appCredentialData,
serviceAccountCredentialData,
isAdmin,
onSuccess,
}: DriveJsonUploadSectionProps) => {
const { mutate } = useSWRConfig();
const router = useRouter();
const [localServiceAccountData, setLocalServiceAccountData] = useState(
serviceAccountCredentialData
);
const [localAppCredentialData, setLocalAppCredentialData] =
useState(appCredentialData);
if (serviceAccountCredentialData?.service_account_email) {
useEffect(() => {
setLocalServiceAccountData(serviceAccountCredentialData);
setLocalAppCredentialData(appCredentialData);
}, [serviceAccountCredentialData, appCredentialData]);
const handleSuccess = () => {
if (onSuccess) {
onSuccess();
} else {
refreshAllGoogleData(ValidSources.GoogleDrive);
}
};
if (localServiceAccountData?.service_account_email) {
return (
<div className="mt-2 text-sm">
<div>
Found existing service account key with the following <b>Email:</b>
<p className="italic mt-1">
{serviceAccountCredentialData.service_account_email}
{localServiceAccountData.service_account_email}
</p>
</div>
{isAdmin ? (
@@ -188,11 +218,15 @@ export const DriveJsonUploadSection = ({
mutate(
"/api/manage/admin/connector/google-drive/service-account-key"
);
mutate(
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
);
setPopup({
message: "Successfully deleted service account key",
type: "success",
});
router.refresh();
setLocalServiceAccountData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
@@ -216,12 +250,12 @@ export const DriveJsonUploadSection = ({
);
}
if (appCredentialData?.client_id) {
if (localAppCredentialData?.client_id) {
return (
<div className="mt-2 text-sm">
<div>
Found existing app credentials with the following <b>Client ID:</b>
<p className="italic mt-1">{appCredentialData.client_id}</p>
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
</div>
{isAdmin ? (
<>
@@ -242,10 +276,15 @@ export const DriveJsonUploadSection = ({
mutate(
"/api/manage/admin/connector/google-drive/app-credential"
);
mutate(
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive)
);
setPopup({
message: "Successfully deleted app credentials",
type: "success",
});
setLocalAppCredentialData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
@@ -297,7 +336,7 @@ export const DriveJsonUploadSection = ({
Download the credentials JSON if choosing option (1) or the Service
Account key JSON if choosing option (2), and upload it here.
</p>
<DriveJsonUpload setPopup={setPopup} />
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
</div>
);
};
@@ -348,13 +387,41 @@ export const DriveAuthSection = ({
appCredentialData,
setPopup,
refreshCredentials,
connectorAssociated, // don't allow revoke if a connector / credential pair is active with the uploaded credential
connectorAssociated,
user,
}: DriveCredentialSectionProps) => {
const router = useRouter();
const [localServiceAccountData, setLocalServiceAccountData] = useState(
serviceAccountKeyData
);
const [localAppCredentialData, setLocalAppCredentialData] =
useState(appCredentialData);
const [
localGoogleDrivePublicCredential,
setLocalGoogleDrivePublicCredential,
] = useState(googleDrivePublicUploadedCredential);
const [
localGoogleDriveServiceAccountCredential,
setLocalGoogleDriveServiceAccountCredential,
] = useState(googleDriveServiceAccountCredential);
useEffect(() => {
setLocalServiceAccountData(serviceAccountKeyData);
setLocalAppCredentialData(appCredentialData);
setLocalGoogleDrivePublicCredential(googleDrivePublicUploadedCredential);
setLocalGoogleDriveServiceAccountCredential(
googleDriveServiceAccountCredential
);
}, [
serviceAccountKeyData,
appCredentialData,
googleDrivePublicUploadedCredential,
googleDriveServiceAccountCredential,
]);
const existingCredential =
googleDrivePublicUploadedCredential || googleDriveServiceAccountCredential;
localGoogleDrivePublicCredential ||
localGoogleDriveServiceAccountCredential;
if (existingCredential) {
return (
<>
@@ -377,7 +444,7 @@ export const DriveAuthSection = ({
);
}
if (serviceAccountKeyData?.service_account_email) {
if (localServiceAccountData?.service_account_email) {
return (
<div>
<Formik
@@ -438,7 +505,7 @@ export const DriveAuthSection = ({
);
}
if (appCredentialData?.client_id) {
if (localAppCredentialData?.client_id) {
return (
<div className="text-sm mb-4">
<p className="mb-2">

View File

@@ -1,8 +1,7 @@
"use client";
import React, { useEffect, useState } from "react";
import useSWR, { mutate } from "swr";
import { FetchError, errorHandlingFetcher } from "@/lib/fetcher";
import React from "react";
import { FetchError } from "@/lib/fetcher";
import { ErrorCallout } from "@/components/ErrorCallout";
import { LoadingAnimation } from "@/components/Loading";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
@@ -15,22 +14,17 @@ import {
GoogleDriveCredentialJson,
GoogleDriveServiceAccountCredentialJson,
} from "@/lib/connectors/credentials";
import { ConnectorSnapshot } from "@/lib/connectors/connectors";
import { useUser } from "@/components/user/UserProvider";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
const useConnectorsByCredentialId = (credential_id: number | null) => {
let url: string | null = null;
if (credential_id !== null) {
url = `/api/manage/admin/connector?credential=${credential_id}`;
}
const swrResponse = useSWR<ConnectorSnapshot[]>(url, errorHandlingFetcher);
return {
...swrResponse,
refreshConnectorsByCredentialId: () => mutate(url),
};
};
import {
useGoogleAppCredential,
useGoogleServiceAccountKey,
useGoogleCredentials,
useConnectorsByCredentialId,
checkCredentialsFetched,
filterUploadedCredentials,
checkConnectorsExist,
refreshAllGoogleData,
} from "@/lib/googleConnector";
const GDriveMain = ({
setPopup,
@@ -39,27 +33,20 @@ const GDriveMain = ({
}) => {
const { isAdmin, user } = useUser();
// tries getting the uploaded credential json
// Get app credential and service account key
const {
data: appCredentialData,
isLoading: isAppCredentialLoading,
error: isAppCredentialError,
} = useSWR<{ client_id: string }, FetchError>(
"/api/manage/admin/connector/google-drive/app-credential",
errorHandlingFetcher
);
} = useGoogleAppCredential("google_drive");
// tries getting the uploaded service account key
const {
data: serviceAccountKeyData,
isLoading: isServiceAccountKeyLoading,
error: isServiceAccountKeyError,
} = useSWR<{ service_account_email: string }, FetchError>(
"/api/manage/admin/connector/google-drive/service-account-key",
errorHandlingFetcher
);
} = useGoogleServiceAccountKey("google_drive");
// gets all public credentials
// Get all public credentials
const {
data: credentialsData,
isLoading: isCredentialsLoading,
@@ -67,33 +54,19 @@ const GDriveMain = ({
refreshCredentials,
} = usePublicCredentials();
// gets all credentials for source type google drive
// Get Google Drive-specific credentials
const {
data: googleDriveCredentials,
isLoading: isGoogleDriveCredentialsLoading,
error: googleDriveCredentialsError,
} = useSWR<Credential<any>[]>(
buildSimilarCredentialInfoURL(ValidSources.GoogleDrive),
errorHandlingFetcher,
{ refreshInterval: 5000 }
} = useGoogleCredentials(ValidSources.GoogleDrive);
// Filter uploaded credentials and get credential ID
const { credential_id, uploadedCredentials } = filterUploadedCredentials(
googleDriveCredentials
);
// filters down to just credentials that were created via upload (there should be only one)
let credential_id = null;
if (googleDriveCredentials) {
const googleDriveUploadedCredentials: Credential<GoogleDriveCredentialJson>[] =
googleDriveCredentials.filter(
(googleDriveCredential) =>
googleDriveCredential.credential_json.authentication_method !==
"oauth_interactive"
);
if (googleDriveUploadedCredentials.length > 0) {
credential_id = googleDriveUploadedCredentials[0].id;
}
}
// retrieves all connectors for that credential id
// Get connectors for the credential ID
const {
data: googleDriveConnectors,
isLoading: isGoogleDriveConnectorsLoading,
@@ -101,13 +74,25 @@ const GDriveMain = ({
refreshConnectorsByCredentialId,
} = useConnectorsByCredentialId(credential_id);
const appCredentialSuccessfullyFetched =
appCredentialData ||
(isAppCredentialError && isAppCredentialError.status === 404);
const serviceAccountKeySuccessfullyFetched =
serviceAccountKeyData ||
(isServiceAccountKeyError && isServiceAccountKeyError.status === 404);
// Check if credentials were successfully fetched
const {
appCredentialSuccessfullyFetched,
serviceAccountKeySuccessfullyFetched,
} = checkCredentialsFetched(
appCredentialData,
isAppCredentialError,
serviceAccountKeyData,
isServiceAccountKeyError
);
// Handle refresh of all data
const handleRefresh = () => {
refreshCredentials();
refreshConnectorsByCredentialId();
refreshAllGoogleData(ValidSources.GoogleDrive);
};
// Loading state
if (
(!appCredentialSuccessfullyFetched && isAppCredentialLoading) ||
(!serviceAccountKeySuccessfullyFetched && isServiceAccountKeyLoading) ||
@@ -122,6 +107,7 @@ const GDriveMain = ({
);
}
// Error states
if (credentialsError || !credentialsData) {
return <ErrorCallout errorTitle="Failed to load credentials." />;
}
@@ -141,7 +127,16 @@ const GDriveMain = ({
);
}
// get the actual uploaded oauth or service account credentials
if (googleDriveConnectorsError) {
return (
<ErrorCallout errorTitle="Failed to load Google Drive associated connectors." />
);
}
// Check if connectors exist
const connectorAssociated = checkConnectorsExist(googleDriveConnectors);
// Get the uploaded OAuth credential
const googleDrivePublicUploadedCredential:
| Credential<GoogleDriveCredentialJson>
| undefined = credentialsData.find(
@@ -152,6 +147,7 @@ const GDriveMain = ({
credential.credential_json.authentication_method !== "oauth_interactive"
);
// Get the service account credential
const googleDriveServiceAccountCredential:
| Credential<GoogleDriveServiceAccountCredentialJson>
| undefined = credentialsData.find(
@@ -160,19 +156,6 @@ const GDriveMain = ({
credential.source === "google_drive"
);
if (googleDriveConnectorsError) {
return (
<ErrorCallout errorTitle="Failed to load Google Drive associated connectors." />
);
}
let connectorAssociated = false;
if (googleDriveConnectors) {
if (googleDriveConnectors.length > 0) {
connectorAssociated = true;
}
}
return (
<>
<Title className="mb-2 mt-6">Step 1: Provide your Credentials</Title>
@@ -181,27 +164,30 @@ const GDriveMain = ({
appCredentialData={appCredentialData}
serviceAccountCredentialData={serviceAccountKeyData}
isAdmin={isAdmin}
onSuccess={handleRefresh}
/>
{isAdmin && (
<>
<Title className="mb-2 mt-6">Step 2: Authenticate with Onyx</Title>
<DriveAuthSection
setPopup={setPopup}
refreshCredentials={refreshCredentials}
googleDrivePublicUploadedCredential={
googleDrivePublicUploadedCredential
}
googleDriveServiceAccountCredential={
googleDriveServiceAccountCredential
}
appCredentialData={appCredentialData}
serviceAccountKeyData={serviceAccountKeyData}
connectorAssociated={connectorAssociated}
user={user}
/>
</>
)}
{isAdmin &&
(appCredentialData?.client_id ||
serviceAccountKeyData?.service_account_email) && (
<>
<Title className="mb-2 mt-6">Step 2: Authenticate with Onyx</Title>
<DriveAuthSection
setPopup={setPopup}
refreshCredentials={handleRefresh}
googleDrivePublicUploadedCredential={
googleDrivePublicUploadedCredential
}
googleDriveServiceAccountCredential={
googleDriveServiceAccountCredential
}
appCredentialData={appCredentialData}
serviceAccountKeyData={serviceAccountKeyData}
connectorAssociated={connectorAssociated}
user={user}
/>
</>
)}
</>
);
};

View File

@@ -1,6 +1,6 @@
import { Button } from "@/components/Button";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { useState } from "react";
import React, { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
import * as Yup from "yup";
import { useRouter } from "next/navigation";
@@ -17,13 +17,18 @@ import {
GmailCredentialJson,
GmailServiceAccountCredentialJson,
} from "@/lib/connectors/credentials";
import { refreshAllGoogleData } from "@/lib/googleConnector";
import { ValidSources } from "@/lib/types";
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
type GmailCredentialJsonTypes = "authorized_user" | "service_account";
const DriveJsonUpload = ({
setPopup,
onSuccess,
}: {
setPopup: (popupSpec: PopupSpec | null) => void;
onSuccess?: () => void;
}) => {
const { mutate } = useSWRConfig();
const [credentialJsonStr, setCredentialJsonStr] = useState<
@@ -72,7 +77,7 @@ const DriveJsonUpload = ({
credentialFileType = "service_account";
} else {
throw new Error(
"Unknown credential type, expected 'OAuth Web application'"
"Unknown credential type, expected one of 'OAuth Web application' or 'Service Account'"
);
}
} catch (e) {
@@ -99,6 +104,10 @@ const DriveJsonUpload = ({
message: "Successfully uploaded app credentials",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/app-credential");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
@@ -106,7 +115,6 @@ const DriveJsonUpload = ({
type: "error",
});
}
mutate("/api/manage/admin/connector/gmail/app-credential");
}
if (credentialFileType === "service_account") {
@@ -122,17 +130,20 @@ const DriveJsonUpload = ({
);
if (response.ok) {
setPopup({
message: "Successfully uploaded app credentials",
message: "Successfully uploaded service account key",
type: "success",
});
mutate("/api/manage/admin/connector/gmail/service-account-key");
if (onSuccess) {
onSuccess();
}
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to upload app credentials - ${errorMsg}`,
message: `Failed to upload service account key - ${errorMsg}`,
type: "error",
});
}
mutate("/api/manage/admin/connector/gmail/service-account-key");
}
}}
>
@@ -147,6 +158,7 @@ interface DriveJsonUploadSectionProps {
appCredentialData?: { client_id: string };
serviceAccountCredentialData?: { service_account_email: string };
isAdmin: boolean;
onSuccess?: () => void;
}
export const GmailJsonUploadSection = ({
@@ -154,16 +166,37 @@ export const GmailJsonUploadSection = ({
appCredentialData,
serviceAccountCredentialData,
isAdmin,
onSuccess,
}: DriveJsonUploadSectionProps) => {
const { mutate } = useSWRConfig();
const router = useRouter();
const [localServiceAccountData, setLocalServiceAccountData] = useState(
serviceAccountCredentialData
);
const [localAppCredentialData, setLocalAppCredentialData] =
useState(appCredentialData);
if (serviceAccountCredentialData?.service_account_email) {
// Update local state when props change
useEffect(() => {
setLocalServiceAccountData(serviceAccountCredentialData);
setLocalAppCredentialData(appCredentialData);
}, [serviceAccountCredentialData, appCredentialData]);
const handleSuccess = () => {
if (onSuccess) {
onSuccess();
} else {
refreshAllGoogleData(ValidSources.Gmail);
}
};
if (localServiceAccountData?.service_account_email) {
return (
<div className="mt-2 text-sm">
<div>
Found existing service account key with the following <b>Email:</b>
<p className="italic mt-1">
{serviceAccountCredentialData.service_account_email}
{localServiceAccountData.service_account_email}
</p>
</div>
{isAdmin ? (
@@ -185,10 +218,15 @@ export const GmailJsonUploadSection = ({
mutate(
"/api/manage/admin/connector/gmail/service-account-key"
);
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
setPopup({
message: "Successfully deleted service account key",
type: "success",
});
// Immediately update local state
setLocalServiceAccountData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
@@ -212,43 +250,56 @@ export const GmailJsonUploadSection = ({
);
}
if (appCredentialData?.client_id) {
if (localAppCredentialData?.client_id) {
return (
<div className="mt-2 text-sm">
<div>
Found existing app credentials with the following <b>Client ID:</b>
<p className="italic mt-1">{appCredentialData.client_id}</p>
<p className="italic mt-1">{localAppCredentialData.client_id}</p>
</div>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/gmail/app-credential",
{
method: "DELETE",
}
);
if (response.ok) {
mutate("/api/manage/admin/connector/gmail/app-credential");
setPopup({
message: "Successfully deleted service account key",
type: "success",
});
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete app credential - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
{isAdmin ? (
<>
<div className="mt-4 mb-1">
If you want to update these credentials, delete the existing
credentials through the button below, and then upload a new
credentials JSON.
</div>
<Button
onClick={async () => {
const response = await fetch(
"/api/manage/admin/connector/gmail/app-credential",
{
method: "DELETE",
}
);
if (response.ok) {
mutate("/api/manage/admin/connector/gmail/app-credential");
// Also mutate the credential endpoints to ensure Step 2 is reset
mutate(buildSimilarCredentialInfoURL(ValidSources.Gmail));
setPopup({
message: "Successfully deleted app credentials",
type: "success",
});
// Immediately update local state
setLocalAppCredentialData(undefined);
handleSuccess();
} else {
const errorMsg = await response.text();
setPopup({
message: `Failed to delete app credential - ${errorMsg}`,
type: "error",
});
}
}}
>
Delete
</Button>
</>
) : (
<div className="mt-4 mb-1">
To change these credentials, please contact an administrator.
</div>
)}
</div>
);
}
@@ -276,14 +327,14 @@ export const GmailJsonUploadSection = ({
>
here
</a>{" "}
to either (1) setup a google OAuth App in your company workspace or (2)
to either (1) setup a Google OAuth App in your company workspace or (2)
create a Service Account.
<br />
<br />
Download the credentials JSON if choosing option (1) or the Service
Account key JSON if chooosing option (2), and upload it here.
Account key JSON if choosing option (2), and upload it here.
</p>
<DriveJsonUpload setPopup={setPopup} />
<DriveJsonUpload setPopup={setPopup} onSuccess={handleSuccess} />
</div>
);
};
@@ -299,6 +350,34 @@ interface DriveCredentialSectionProps {
user: User | null;
}
async function handleRevokeAccess(
connectorExists: boolean,
setPopup: (popupSpec: PopupSpec | null) => void,
existingCredential:
| Credential<GmailCredentialJson>
| Credential<GmailServiceAccountCredentialJson>,
refreshCredentials: () => void
) {
if (connectorExists) {
const message =
"Cannot revoke the Gmail credential while any connector is still associated with the credential. " +
"Please delete all associated connectors, then try again.";
setPopup({
message: message,
type: "error",
});
return;
}
await adminDeleteCredential(existingCredential.id);
setPopup({
message: "Successfully revoked the Gmail credential!",
type: "success",
});
refreshCredentials();
}
export const GmailAuthSection = ({
gmailPublicCredential,
gmailServiceAccountCredential,
@@ -310,31 +389,49 @@ export const GmailAuthSection = ({
user,
}: DriveCredentialSectionProps) => {
const router = useRouter();
const [isAuthenticating, setIsAuthenticating] = useState(false);
const [localServiceAccountData, setLocalServiceAccountData] = useState(
serviceAccountKeyData
);
const [localAppCredentialData, setLocalAppCredentialData] =
useState(appCredentialData);
const [localGmailPublicCredential, setLocalGmailPublicCredential] = useState(
gmailPublicCredential
);
const [
localGmailServiceAccountCredential,
setLocalGmailServiceAccountCredential,
] = useState(gmailServiceAccountCredential);
// Update local state when props change
useEffect(() => {
setLocalServiceAccountData(serviceAccountKeyData);
setLocalAppCredentialData(appCredentialData);
setLocalGmailPublicCredential(gmailPublicCredential);
setLocalGmailServiceAccountCredential(gmailServiceAccountCredential);
}, [
serviceAccountKeyData,
appCredentialData,
gmailPublicCredential,
gmailServiceAccountCredential,
]);
const existingCredential =
gmailPublicCredential || gmailServiceAccountCredential;
localGmailPublicCredential || localGmailServiceAccountCredential;
if (existingCredential) {
return (
<>
<p className="mb-2 text-sm">
<i>Existing credential already set up!</i>
<i>Uploaded and authenticated credential already exists!</i>
</p>
<Button
onClick={async () => {
if (connectorExists) {
setPopup({
message:
"Cannot revoke access to Gmail while any connector is still set up. Please delete all connectors, then try again.",
type: "error",
});
return;
}
await adminDeleteCredential(existingCredential.id);
setPopup({
message: "Successfully revoked access to Gmail!",
type: "success",
});
refreshCredentials();
handleRevokeAccess(
connectorExists,
setPopup,
existingCredential,
refreshCredentials
);
}}
>
Revoke Access
@@ -343,20 +440,21 @@ export const GmailAuthSection = ({
);
}
if (serviceAccountKeyData?.service_account_email) {
if (localServiceAccountData?.service_account_email) {
return (
<div>
<CardSection>
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string().required(),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
<Formik
initialValues={{
google_primary_admin: user?.email || "",
}}
validationSchema={Yup.object().shape({
google_primary_admin: Yup.string()
.email("Must be a valid email")
.required("Required"),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
const response = await fetch(
"/api/manage/admin/connector/gmail/service-account-credential",
{
@@ -375,6 +473,7 @@ export const GmailAuthSection = ({
message: "Successfully created service account credential",
type: "success",
});
refreshCredentials();
} else {
const errorMsg = await response.text();
setPopup({
@@ -382,65 +481,73 @@ export const GmailAuthSection = ({
type: "error",
});
}
refreshCredentials();
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="You must provide an admin/owner account to retrieve all org emails."
/>
<div className="flex">
<button
type="submit"
disabled={isSubmitting}
className={
"bg-slate-500 hover:bg-slate-700 text-white " +
"font-bold py-2 px-4 rounded focus:outline-none " +
"focus:shadow-outline w-full max-w-sm mx-auto"
}
>
Submit
</button>
</div>
</Form>
)}
</Formik>
</CardSection>
} catch (error) {
setPopup({
message: `Failed to create service account credential - ${error}`,
type: "error",
});
} finally {
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting }) => (
<Form>
<TextFormField
name="google_primary_admin"
label="Primary Admin Email:"
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
/>
<div className="flex">
<Button type="submit" disabled={isSubmitting}>
Create Credential
</Button>
</div>
</Form>
)}
</Formik>
</div>
);
}
if (appCredentialData?.client_id) {
if (localAppCredentialData?.client_id) {
return (
<div className="text-sm mb-4">
<p className="mb-2">
Next, you must provide credentials via OAuth. This gives us read
access to the docs you have access to in your gmail account.
access to the emails you have access to in your Gmail account.
</p>
<Button
onClick={async () => {
const [authUrl, errorMsg] = await setupGmailOAuth({
isAdmin: true,
});
if (authUrl) {
// cookie used by callback to determine where to finally redirect to
setIsAuthenticating(true);
try {
Cookies.set(GMAIL_AUTH_IS_ADMIN_COOKIE_NAME, "true", {
path: "/",
});
router.push(authUrl);
return;
}
const [authUrl, errorMsg] = await setupGmailOAuth({
isAdmin: true,
});
setPopup({
message: errorMsg,
type: "error",
});
if (authUrl) {
router.push(authUrl);
} else {
setPopup({
message: errorMsg,
type: "error",
});
setIsAuthenticating(false);
}
} catch (error) {
setPopup({
message: `Failed to authenticate with Gmail - ${error}`,
type: "error",
});
setIsAuthenticating(false);
}
}}
disabled={isAuthenticating}
>
Authenticate with Gmail
{isAuthenticating ? "Authenticating..." : "Authenticate with Gmail"}
</Button>
</div>
);
@@ -449,8 +556,8 @@ export const GmailAuthSection = ({
// case where no keys have been uploaded in step 1
return (
<p className="text-sm">
Please upload an OAuth or Service Account Credential JSON in Step 1 before
moving onto Step 2.
Please upload either an OAuth Client Credential JSON or a Gmail Service
Account Key JSON in Step 1 before moving on to Step 2.
</p>
);
};


@@ -1,10 +1,11 @@
"use client";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import React from "react";
import { FetchError } from "@/lib/fetcher";
import { ErrorCallout } from "@/components/ErrorCallout";
import { LoadingAnimation } from "@/components/Loading";
import { usePopup } from "@/components/admin/connectors/Popup";
import { CCPairBasicInfo } from "@/lib/types";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { CCPairBasicInfo, ValidSources } from "@/lib/types";
import {
Credential,
GmailCredentialJson,
@@ -14,26 +15,33 @@ import { GmailAuthSection, GmailJsonUploadSection } from "./Credential";
import { usePublicCredentials, useBasicConnectorStatus } from "@/lib/hooks";
import Title from "@/components/ui/title";
import { useUser } from "@/components/user/UserProvider";
import {
useGoogleAppCredential,
useGoogleServiceAccountKey,
useGoogleCredentials,
useConnectorsByCredentialId,
checkCredentialsFetched,
filterUploadedCredentials,
checkConnectorsExist,
refreshAllGoogleData,
} from "@/lib/googleConnector";
export const GmailMain = () => {
const { isAdmin, user } = useUser();
const { popup, setPopup } = usePopup();
const {
data: appCredentialData,
isLoading: isAppCredentialLoading,
error: isAppCredentialError,
} = useSWR<{ client_id: string }>(
"/api/manage/admin/connector/gmail/app-credential",
errorHandlingFetcher
);
} = useGoogleAppCredential("gmail");
const {
data: serviceAccountKeyData,
isLoading: isServiceAccountKeyLoading,
error: isServiceAccountKeyError,
} = useSWR<{ service_account_email: string }>(
"/api/manage/admin/connector/gmail/service-account-key",
errorHandlingFetcher
);
} = useGoogleServiceAccountKey("gmail");
const {
data: connectorIndexingStatuses,
isLoading: isConnectorIndexingStatusesLoading,
@@ -47,20 +55,45 @@ export const GmailMain = () => {
refreshCredentials,
} = usePublicCredentials();
const { popup, setPopup } = usePopup();
const {
data: gmailCredentials,
isLoading: isGmailCredentialsLoading,
error: gmailCredentialsError,
} = useGoogleCredentials(ValidSources.Gmail);
const appCredentialSuccessfullyFetched =
appCredentialData ||
(isAppCredentialError && isAppCredentialError.status === 404);
const serviceAccountKeySuccessfullyFetched =
serviceAccountKeyData ||
(isServiceAccountKeyError && isServiceAccountKeyError.status === 404);
const { credential_id, uploadedCredentials } =
filterUploadedCredentials(gmailCredentials);
const {
data: gmailConnectors,
isLoading: isGmailConnectorsLoading,
error: gmailConnectorsError,
refreshConnectorsByCredentialId,
} = useConnectorsByCredentialId(credential_id);
const {
appCredentialSuccessfullyFetched,
serviceAccountKeySuccessfullyFetched,
} = checkCredentialsFetched(
appCredentialData,
isAppCredentialError,
serviceAccountKeyData,
isServiceAccountKeyError
);
const handleRefresh = () => {
refreshCredentials();
refreshConnectorsByCredentialId();
refreshAllGoogleData(ValidSources.Gmail);
};
if (
(!appCredentialSuccessfullyFetched && isAppCredentialLoading) ||
(!serviceAccountKeySuccessfullyFetched && isServiceAccountKeyLoading) ||
(!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) ||
(!credentialsData && isCredentialsLoading)
(!credentialsData && isCredentialsLoading) ||
(!gmailCredentials && isGmailCredentialsLoading) ||
(!gmailConnectors && isGmailConnectorsLoading)
) {
return (
<div className="mx-auto">
@@ -70,19 +103,15 @@ export const GmailMain = () => {
}
if (credentialsError || !credentialsData) {
return (
<div className="mx-auto">
<div className="text-red-500">Failed to load credentials.</div>
</div>
);
return <ErrorCallout errorTitle="Failed to load credentials." />;
}
if (gmailCredentialsError || !gmailCredentials) {
return <ErrorCallout errorTitle="Failed to load Gmail credentials." />;
}
if (connectorIndexingStatusesError || !connectorIndexingStatuses) {
return (
<div className="mx-auto">
<div className="text-red-500">Failed to load connectors.</div>
</div>
);
return <ErrorCallout errorTitle="Failed to load connectors." />;
}
if (
@@ -90,21 +119,28 @@ export const GmailMain = () => {
!serviceAccountKeySuccessfullyFetched
) {
return (
<div className="mx-auto">
<div className="text-red-500">
Error loading Gmail app credentials. Contact an administrator.
</div>
</div>
<ErrorCallout errorTitle="Error loading Gmail app credentials. Contact an administrator." />
);
}
const gmailPublicCredential: Credential<GmailCredentialJson> | undefined =
credentialsData.find(
(credential) =>
(credential.credential_json?.google_service_account_key ||
credential.credential_json?.google_tokens) &&
credential.admin_public
if (gmailConnectorsError) {
return (
<ErrorCallout errorTitle="Failed to load Gmail associated connectors." />
);
}
const connectorExistsFromCredential = checkConnectorsExist(gmailConnectors);
const gmailPublicUploadedCredential:
| Credential<GmailCredentialJson>
| undefined = credentialsData.find(
(credential) =>
credential.credential_json?.google_tokens &&
credential.admin_public &&
credential.source === "gmail" &&
credential.credential_json.authentication_method !== "oauth_interactive"
);
const gmailServiceAccountCredential:
| Credential<GmailServiceAccountCredentialJson>
| undefined = credentialsData.find(
@@ -118,6 +154,13 @@ export const GmailMain = () => {
(connectorIndexingStatus) => connectorIndexingStatus.source === "gmail"
);
const connectorExists =
connectorExistsFromCredential || gmailConnectorIndexingStatuses.length > 0;
const hasUploadedCredentials =
Boolean(appCredentialData?.client_id) ||
Boolean(serviceAccountKeyData?.service_account_email);
return (
<>
{popup}
@@ -129,21 +172,22 @@ export const GmailMain = () => {
appCredentialData={appCredentialData}
serviceAccountCredentialData={serviceAccountKeyData}
isAdmin={isAdmin}
onSuccess={handleRefresh}
/>
{isAdmin && (
{isAdmin && hasUploadedCredentials && (
<>
<Title className="mb-2 mt-6 ml-auto mr-auto">
Step 2: Authenticate with Onyx
</Title>
<GmailAuthSection
setPopup={setPopup}
refreshCredentials={refreshCredentials}
gmailPublicCredential={gmailPublicCredential}
refreshCredentials={handleRefresh}
gmailPublicCredential={gmailPublicUploadedCredential}
gmailServiceAccountCredential={gmailServiceAccountCredential}
appCredentialData={appCredentialData}
serviceAccountKeyData={serviceAccountKeyData}
connectorExists={gmailConnectorIndexingStatuses.length > 0}
connectorExists={connectorExists}
user={user}
/>
</>


@@ -4,6 +4,12 @@ export enum ApplicationStatus {
ACTIVE = "active",
}
export enum QueryHistoryType {
DISABLED = "disabled",
ANONYMIZED = "anonymized",
NORMAL = "normal",
}
export interface Settings {
anonymous_user_enabled: boolean;
maximum_chat_retention_days: number | null;
@@ -14,6 +20,7 @@ export interface Settings {
application_status: ApplicationStatus;
auto_scroll: boolean;
temperature_override_enabled: boolean;
query_history_type: QueryHistoryType;
}
export enum NotificationType {

View File

@@ -36,11 +36,14 @@ export function EmailPasswordForm({
{popup}
<Formik
initialValues={{
email: defaultEmail || "",
email: defaultEmail ? defaultEmail.toLowerCase() : "",
password: "",
}}
validationSchema={Yup.object().shape({
email: Yup.string().email().required(),
email: Yup.string()
.email()
.required()
.transform((value) => value.toLowerCase()),
password: Yup.string().required(),
})}
onSubmit={async (values) => {

View File

@@ -56,6 +56,7 @@ import {
Dispatch,
SetStateAction,
use,
useCallback,
useContext,
useEffect,
useLayoutEffect,
@@ -131,17 +132,12 @@ import {
import { getSourceMetadata } from "@/lib/sources";
import { UserSettingsModal } from "./modal/UserSettingsModal";
import { AlignStartVertical } from "lucide-react";
import { AgenticMessage } from "./message/AgenticMessage";
import AssistantModal from "../assistants/mine/AssistantModal";
import {
OperatingSystem,
useOperatingSystem,
useSidebarShortcut,
} from "@/lib/browserUtilities";
import { Button } from "@/components/ui/button";
import { useSidebarShortcut } from "@/lib/browserUtilities";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { MessageChannel } from "node:worker_threads";
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
import { ErrorBanner } from "./message/Resubmit";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -870,6 +866,7 @@ export function ChatPage({
}, [liveAssistant]);
const filterManager = useFilters();
const [isChatSearchModalOpen, setIsChatSearchModalOpen] = useState(false);
const [currentFeedback, setCurrentFeedback] = useState<
[FeedbackType, number] | null
@@ -891,24 +888,6 @@ export function ChatPage({
);
const scrollDist = useRef<number>(0);
const updateScrollTracking = () => {
const scrollDistance =
endDivRef?.current?.getBoundingClientRect()?.top! -
inputRef?.current?.getBoundingClientRect()?.top!;
scrollDist.current = scrollDistance;
setAboveHorizon(scrollDist.current > 500);
};
useEffect(() => {
const scrollableDiv = scrollableDivRef.current;
if (scrollableDiv) {
scrollableDiv.addEventListener("scroll", updateScrollTracking);
return () => {
scrollableDiv.removeEventListener("scroll", updateScrollTracking);
};
}
}, []);
const handleInputResize = () => {
setTimeout(() => {
if (
@@ -960,33 +939,12 @@ export function ChatPage({
if (isVisible) return;
// Check if all messages are currently rendered
if (currentVisibleRange.end < messageHistory.length) {
// Update visible range to include the last messages
updateCurrentVisibleRange({
start: Math.max(
0,
messageHistory.length -
(currentVisibleRange.end - currentVisibleRange.start)
),
end: messageHistory.length,
mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId,
});
// If all messages are already rendered, scroll immediately
endDivRef.current.scrollIntoView({
behavior: fast ? "auto" : "smooth",
});
// Wait for the state update and re-render before scrolling
setTimeout(() => {
endDivRef.current?.scrollIntoView({
behavior: fast ? "auto" : "smooth",
});
setHasPerformedInitialScroll(true);
}, 100);
} else {
// If all messages are already rendered, scroll immediately
endDivRef.current.scrollIntoView({
behavior: fast ? "auto" : "smooth",
});
setHasPerformedInitialScroll(true);
}
setHasPerformedInitialScroll(true);
}, 50);
// Reset waitForScrollRef after 1.5 seconds
@@ -1007,11 +965,6 @@ export function ChatPage({
handleInputResize();
}, [message]);
// tracks scrolling
useEffect(() => {
updateScrollTracking();
}, [messageHistory]);
// used for resizing of the document sidebar
const masterFlexboxRef = useRef<HTMLDivElement>(null);
const [maxDocumentSidebarWidth, setMaxDocumentSidebarWidth] = useState<
@@ -1210,6 +1163,7 @@ export function ChatPage({
navigatingAway.current = false;
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
setUncaughtError(null);
// Mark that we've sent a message for this session in the current page load
markSessionMessageSent(frozenSessionId);
@@ -1360,6 +1314,7 @@ export function ChatPage({
let isStreamingQuestions = true;
let includeAgentic = false;
let secondLevelMessageId: number | null = null;
let isAgentic: boolean = false;
let initialFetchDetails: null | {
user_message_id: number;
@@ -1522,6 +1477,9 @@ export function ChatPage({
second_level_generating = true;
}
}
if (Object.hasOwn(packet, "is_agentic")) {
isAgentic = (packet as any).is_agentic;
}
if (Object.hasOwn(packet, "refined_answer_improvement")) {
isImprovement = (packet as RefinedAnswerImprovement)
@@ -1555,6 +1513,7 @@ export function ChatPage({
);
} else if (Object.hasOwn(packet, "sub_question")) {
updateChatState("toolBuilding", frozenSessionId);
isAgentic = true;
is_generating = true;
sub_questions = constructSubQuestions(
sub_questions,
@@ -1755,6 +1714,7 @@ export function ChatPage({
sub_questions: sub_questions,
second_level_generating: second_level_generating,
agentic_docs: agenticDocs,
is_agentic: isAgentic,
},
...(includeAgentic
? [
@@ -1975,122 +1935,6 @@ export function ChatPage({
// Virtualization + Scrolling related effects and functions
const scrollInitialized = useRef(false);
interface VisibleRange {
start: number;
end: number;
mostVisibleMessageId: number | null;
}
const [visibleRange, setVisibleRange] = useState<
Map<string | null, VisibleRange>
>(() => {
const initialRange: VisibleRange = {
start: 0,
end: BUFFER_COUNT,
mostVisibleMessageId: null,
};
return new Map([[chatSessionIdRef.current, initialRange]]);
});
// Function used to update current visible range. Only method for updating `visibleRange` state.
const updateCurrentVisibleRange = (
newRange: VisibleRange,
forceUpdate?: boolean
) => {
if (
scrollInitialized.current &&
visibleRange.get(loadedIdSessionRef.current) == undefined &&
!forceUpdate
) {
return;
}
setVisibleRange((prevState) => {
const newState = new Map(prevState);
newState.set(loadedIdSessionRef.current, newRange);
return newState;
});
};
// Set first value for visibleRange state on page load / refresh.
const initializeVisibleRange = () => {
const upToDatemessageHistory = buildLatestMessageChain(
currentMessageMap(completeMessageDetail)
);
if (!scrollInitialized.current && upToDatemessageHistory.length > 0) {
const newEnd = Math.max(upToDatemessageHistory.length, BUFFER_COUNT);
const newStart = Math.max(0, newEnd - BUFFER_COUNT);
const newMostVisibleMessageId =
upToDatemessageHistory[newEnd - 1]?.messageId;
updateCurrentVisibleRange(
{
start: newStart,
end: newEnd,
mostVisibleMessageId: newMostVisibleMessageId,
},
true
);
scrollInitialized.current = true;
}
};
const updateVisibleRangeBasedOnScroll = () => {
if (!scrollInitialized.current) return;
const scrollableDiv = scrollableDivRef.current;
if (!scrollableDiv) return;
const viewportHeight = scrollableDiv.clientHeight;
let mostVisibleMessageIndex = -1;
messageHistory.forEach((message, index) => {
const messageElement = document.getElementById(
`message-${message.messageId}`
);
if (messageElement) {
const rect = messageElement.getBoundingClientRect();
const isVisible = rect.bottom <= viewportHeight && rect.bottom > 0;
if (isVisible && index > mostVisibleMessageIndex) {
mostVisibleMessageIndex = index;
}
}
});
if (mostVisibleMessageIndex !== -1) {
const startIndex = Math.max(0, mostVisibleMessageIndex - BUFFER_COUNT);
const endIndex = Math.min(
messageHistory.length,
mostVisibleMessageIndex + BUFFER_COUNT + 1
);
updateCurrentVisibleRange({
start: startIndex,
end: endIndex,
mostVisibleMessageId: messageHistory[mostVisibleMessageIndex].messageId,
});
}
};
useEffect(() => {
initializeVisibleRange();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [router, messageHistory]);
useLayoutEffect(() => {
const scrollableDiv = scrollableDivRef.current;
const handleScroll = () => {
updateVisibleRangeBasedOnScroll();
};
scrollableDiv?.addEventListener("scroll", handleScroll);
return () => {
scrollableDiv?.removeEventListener("scroll", handleScroll);
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [messageHistory]);
const imageFileInMessageHistory = useMemo(() => {
return messageHistory
@@ -2100,11 +1944,6 @@ export function ChatPage({
);
}, [messageHistory]);
const currentVisibleRange = visibleRange.get(currentSessionId()) || {
start: 0,
end: 0,
mostVisibleMessageId: null,
};
useSendMessageToParent();
useEffect(() => {
@@ -2144,6 +1983,15 @@ export function ChatPage({
const currentPersona = alternativeAssistant || liveAssistant;
const HORIZON_DISTANCE = 800;
const handleScroll = useCallback(() => {
const scrollDistance =
endDivRef?.current?.getBoundingClientRect()?.top! -
inputRef?.current?.getBoundingClientRect()?.top!;
scrollDist.current = scrollDistance;
setAboveHorizon(scrollDist.current > HORIZON_DISTANCE);
}, []);
useEffect(() => {
const handleSlackChatRedirect = async () => {
if (!slackChatId) return;
@@ -2215,6 +2063,26 @@ export function ChatPage({
const [sharedChatSession, setSharedChatSession] =
useState<ChatSession | null>();
const handleResubmitLastMessage = () => {
// Grab the last user-type message
const lastUserMsg = messageHistory
.slice()
.reverse()
.find((m) => m.type === "user");
if (!lastUserMsg) {
setPopup({
message: "No previously-submitted user message found.",
type: "error",
});
return;
}
// We call onSubmit, passing a `messageOverride`
onSubmit({
messageIdToResend: lastUserMsg.messageId,
messageOverride: lastUserMsg.message,
});
};
const showShareModal = (chatSession: ChatSession) => {
setSharedChatSession(chatSession);
};
@@ -2329,6 +2197,11 @@ export function ChatPage({
/>
)}
<ChatSearchModal
open={isChatSearchModalOpen}
onCloseModal={() => setIsChatSearchModalOpen(false)}
/>
{retrievalEnabled && documentSidebarVisible && settings?.isMobile && (
<div className="md:hidden">
<Modal
@@ -2436,6 +2309,9 @@ export function ChatPage({
>
<div className="w-full relative">
<HistorySidebar
toggleChatSessionSearchModal={() =>
setIsChatSearchModalOpen((open) => !open)
}
liveAssistant={liveAssistant}
setShowAssistantsModal={setShowAssistantsModal}
explicitlyUntoggle={explicitlyUntoggle}
@@ -2452,6 +2328,7 @@ export function ChatPage({
showDeleteAllModal={() => setShowDeleteAllModal(true)}
/>
</div>
<div
className={`
flex-none
@@ -2585,6 +2462,7 @@ export function ChatPage({
{...getRootProps()}
>
<div
onScroll={handleScroll}
className={`w-full h-[calc(100vh-160px)] flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
ref={scrollableDivRef}
>
@@ -2642,18 +2520,7 @@ export function ChatPage({
// NOTE: temporarily removing this to fix the scroll bug
// (hasPerformedInitialScroll ? "" : "invisible")
>
{(messageHistory.length < BUFFER_COUNT
? messageHistory
: messageHistory.slice(
currentVisibleRange.start,
currentVisibleRange.end
)
).map((message, fauxIndex) => {
const i =
messageHistory.length < BUFFER_COUNT
? fauxIndex
: fauxIndex + currentVisibleRange.start;
{messageHistory.map((message, i) => {
const messageMap = currentMessageMap(
completeMessageDetail
);
@@ -2798,9 +2665,9 @@ export function ChatPage({
: null
}
>
{message.sub_questions &&
message.sub_questions.length > 0 ? (
{message.is_agentic ? (
<AgenticMessage
resubmit={handleResubmitLastMessage}
error={uncaughtError}
isStreamingQuestions={
message.isStreamingQuestions ?? false
@@ -3148,21 +3015,18 @@ export function ChatPage({
currentPersona={liveAssistant}
messageId={message.messageId}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
{message.stackTrace && (
<span
onClick={() =>
setStackTraceModalContent(
message.stackTrace!
)
}
className="ml-2 cursor-pointer underline"
>
Show stack trace.
</span>
)}
</p>
<ErrorBanner
resubmit={handleResubmitLastMessage}
error={message.message}
showStackTrace={
message.stackTrace
? () =>
setStackTraceModalContent(
message.stackTrace!
)
: undefined
}
/>
}
/>
</div>

Some files were not shown because too many files have changed in this diff