Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-22 18:25:45 +00:00)

Compare commits: server_sid ... fix_edge_c (5 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 9b8f886d99 |  |
|  | f54e406234 |  |
|  | 3c1228b131 |  |
|  | 2cb66ac115 |  |
|  | a3f229669f |  |
.github/workflows/nightly-scan-licenses.yml (vendored, 12 changed lines)
@@ -53,26 +53,22 @@ jobs:
exclude: '(?i)^(pylint|aio[-_]*).*'

- name: Print report
if: always()
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"

- name: Install npm dependencies
working-directory: ./web
run: npm ci

# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.

- name: Run Trivy vulnerability scanner in repo mode
if: always()
uses: aquasecurity/trivy-action@0.29.0
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scan-ref: .
scanners: license
format: table
severity: HIGH,CRITICAL
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL

# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
.github/workflows/pr-python-model-tests.yml (vendored, 24 changed lines)
@@ -17,13 +17,8 @@ env:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}

# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}

jobs:
model-check:

@@ -77,7 +72,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
docker compose -f docker-compose.dev.yml -p onyx-stack up -d indexing_model_server
id: start_docker

- name: Wait for service to be ready

@@ -128,22 +123,9 @@ jobs:
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log

- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
@@ -1,84 +0,0 @@
"""improved index

Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None


# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
#    (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
#    (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search


def upgrade() -> None:
    # Create a GIN index for full-text search on chat_message.message
    op.execute(
        """
        ALTER TABLE chat_message
        ADD COLUMN message_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
        """
    )

    # Commit the current transaction before creating concurrent indexes
    op.execute("COMMIT")

    op.execute(
        """
        CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
        ON chat_message
        USING GIN (message_tsv)
        """
    )

    # Also add a stored tsvector column for chat_session.description
    op.execute(
        """
        ALTER TABLE chat_session
        ADD COLUMN description_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
        """
    )

    # Commit again before creating the second concurrent index
    op.execute("COMMIT")

    op.execute(
        """
        CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
        ON chat_session
        USING GIN (description_tsv)
        """
    )


def downgrade() -> None:
    # Drop the indexes first (use CONCURRENTLY for dropping too)
    op.execute("COMMIT")
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")

    op.execute("COMMIT")
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")

    # Then drop the columns
    op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
    op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")

    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
@@ -1,32 +0,0 @@
"""add index

Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create a basic index on the lowercase message column for direct text matching
    # Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
    # op.execute(
    #     """
    #     CREATE INDEX idx_chat_message_message_lower
    #     ON chat_message (LOWER(substring(message, 1, 1500)))
    #     """
    # )
    pass


def downgrade() -> None:
    # Drop the index
    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
@@ -1,120 +0,0 @@
"""migrate jira connectors to new format

Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040

"""
from alembic import op
import sqlalchemy as sa
import json

from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project


# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Get all Jira connectors
    conn = op.get_bind()

    # First get all Jira connectors
    jira_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = :source
            """
        ),
        {"source": DocumentSource.JIRA.value.upper()},
    ).fetchall()

    # Update each connector's config
    for connector_id, old_config in jira_connectors:
        if not old_config:
            continue

        # Extract project key from URL if it exists
        new_config: dict[str, str | None] = {}
        if project_url := old_config.get("jira_project_url"):
            # Parse the URL to get base and project
            try:
                jira_base, project_key = extract_jira_project(project_url)
                new_config = {"jira_base_url": jira_base, "project_key": project_key}
            except ValueError:
                # If URL parsing fails, just use the URL as the base
                new_config = {
                    "jira_base_url": project_url.split("/projects/")[0],
                    "project_key": None,
                }
        else:
            # For connectors without a project URL, we need admin intervention
            # Mark these for review
            print(
                f"WARNING: Jira connector {connector_id} has no project URL configured"
            )
            continue

        # Update the connector config
        conn.execute(
            sa.text(
                """
                UPDATE connector
                SET connector_specific_config = :new_config
                WHERE id = :id
                """
            ),
            {"id": connector_id, "new_config": json.dumps(new_config)},
        )


def downgrade() -> None:
    # Get all Jira connectors
    conn = op.get_bind()

    # First get all Jira connectors
    jira_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = :source
            """
        ),
        {"source": DocumentSource.JIRA.value.upper()},
    ).fetchall()

    # Update each connector's config back to the old format
    for connector_id, new_config in jira_connectors:
        if not new_config:
            continue

        old_config = {}
        base_url = new_config.get("jira_base_url")
        project_key = new_config.get("project_key")

        if base_url and project_key:
            old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
        elif base_url:
            old_config = {"jira_project_url": base_url}
        else:
            continue

        # Update the connector config
        conn.execute(
            sa.text(
                """
                UPDATE connector
                SET connector_specific_config = :old_config
                WHERE id = :id
                """
            ),
            {"id": connector_id, "old_config": old_config},
        )
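Note (illustrative, not part of the diff): the two connector_specific_config shapes this migration converts between. The dictionary keys come from the migration code above; the URL and project key values are made up.

```python
# Old format: a single project URL
old_config = {"jira_project_url": "https://example.atlassian.net/projects/ENG"}

# New format: base URL plus explicit project key (project_key may be None when
# the URL could not be parsed, per the fallback branch in upgrade() above)
new_config = {"jira_base_url": "https://example.atlassian.net", "project_key": "ENG"}
```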
@@ -5,9 +5,11 @@ from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

@@ -16,8 +18,10 @@ logger = setup_logger()

@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)

@@ -31,19 +35,24 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
def check_ttl_management_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

@celery_app.task(
@@ -51,9 +60,9 @@ def check_ttl_management_task(*, tenant_id: str) -> None:
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,
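Note (not part of the diff): this hunk shows the pattern repeated throughout the rest of the change set. Tasks now take an explicit, optional tenant_id and open a session scoped to it, instead of relying on the contextvar-backed get_session_with_current_tenant(). That tenant_id=None falls back to the default schema is an assumption, inferred from later hunks in this diff that replace POSTGRES_DEFAULT_SCHEMA with tenant_id=None.

```python
# Sketch only; names come from the imports shown above.
def example_tenant_scoped_task(retention_limit_days: int, *, tenant_id: str | None) -> None:
    # tenant_id=None is assumed to target the default schema (see note above)
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        delete_chat_sessions_older_than(retention_limit_days, db_session)
```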
@@ -18,7 +18,7 @@ logger = setup_logger()


def monitor_usergroup_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
"""This function is likely to move in the worker refactor happening next."""
fence_key = key_bytes.decode("utf-8")
@@ -2,7 +2,6 @@ import csv
import io
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from uuid import UUID

from fastapi import APIRouter

@@ -22,10 +21,8 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user

@@ -38,8 +35,6 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse

router = APIRouter()

ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"


def fetch_and_process_chat_session_history(
db_session: Session,

@@ -112,17 +107,6 @@ def get_user_chat_sessions(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
# we specifically don't allow this endpoint if "anonymized" since
# this is a direct query on the user id
if ONYX_QUERY_HISTORY_TYPE in [
QueryHistoryType.DISABLED,
QueryHistoryType.ANONYMIZED,
]:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Per user query history has been disabled by the administrator.",
)

try:
chat_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session, limit=0

@@ -138,7 +122,6 @@ def get_user_chat_sessions(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,

@@ -158,12 +141,6 @@ def get_chat_session_history(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)

page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,

@@ -180,16 +157,11 @@ def get_chat_session_history(
feedback_filter=feedback_type,
)

minimal_chat_sessions: list[ChatSessionMinimal] = []

for chat_session in page_of_chat_sessions:
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
minimal_chat_sessions.append(minimal_chat_session)

return PaginatedReturn(
items=minimal_chat_sessions,
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
total_items=total_filtered_chat_sessions_count,
)

@@ -200,12 +172,6 @@ def get_chat_session_admin(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)

try:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,

@@ -227,9 +193,6 @@ def get_chat_session_admin(
f"Could not create snapshot for chat session with id '{chat_session_id}'",
)

if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL

return snapshot

@@ -240,12 +203,6 @@ def get_query_history_as_csv(
end: datetime | None = None,
db_session: Session = Depends(get_session),
) -> StreamingResponse:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)

complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),

@@ -256,9 +213,6 @@ def get_query_history_as_csv(

question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in complete_chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL

question_answer_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
)
@@ -7,7 +7,6 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger

@@ -42,9 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()


def fetch_billing_information(
tenant_id: str,
) -> BillingInformation | SubscriptionStatusResponse:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {

@@ -55,19 +52,8 @@ def fetch_billing_information(
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()

response_data = response.json()

# Check if the response indicates no subscription
if (
isinstance(response_data, dict)
and "subscribed" in response_data
and not response_data["subscribed"]
):
return SubscriptionStatusResponse(**response_data)

# Otherwise, parse as BillingInformation
return BillingInformation(**response_data)
billing_info = BillingInformation(**response.json())
return billing_info


def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
@@ -200,35 +200,14 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:


def configure_default_api_keys(db_session: Session) -> None:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)

if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)

@@ -240,6 +219,25 @@ def configure_default_api_keys(db_session: Session) -> None:
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)

if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)

if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:


def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)

@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:


def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))

@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:


def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)

@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:


def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
@@ -10,7 +10,6 @@ from pydantic import BaseModel

from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
from shared_configs.configs import MULTI_TENANT


_API_KEY_HEADER_NAME = "Authorization"

@@ -36,7 +35,8 @@ class ApiKeyDescriptor(BaseModel):


def generate_api_key(tenant_id: str | None = None) -> str:
if not MULTI_TENANT or not tenant_id:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)

encoded_tenant = quote(tenant_id)  # URL encode the tenant ID
@@ -2,8 +2,6 @@ import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid

from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM

@@ -15,7 +13,6 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from shared_configs.configs import MULTI_TENANT

HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>

@@ -153,9 +150,8 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
if mail_from:
msg["From"] = mail_from

part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")

@@ -177,7 +173,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We're sorry to see you go.</p>"
"<p>We’re sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)

@@ -243,13 +239,13 @@ def send_user_email_invite(
def send_forgot_password_email(
user_email: str,
token: str,
tenant_id: str,
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)
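Note (illustrative, not part of the diff): the reset link produced by send_forgot_password_email above. The WEB_DOMAIN, TENANT_ID_COOKIE_NAME, token, and tenant values below are hypothetical placeholders.

```python
WEB_DOMAIN = "https://onyx.example.com"      # hypothetical
TENANT_ID_COOKIE_NAME = "onyx_tenant_id"     # hypothetical
token, tenant_id = "abc123", "tenant_42"     # hypothetical

link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if tenant_id:  # appended only when a tenant id is provided, as in the new code above
    link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
# -> https://onyx.example.com/auth/reset-password?token=abc123&onyx_tenant_id=tenant_42
```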
@@ -214,7 +214,7 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("User not on allowed user whitelist")


def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)

@@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}

user: User | None = None
user: User

try:
# Attempt to get user by OAuth account

@@ -420,20 +420,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = await self.user_db.get_by_email(account_email)
user = cast(User, await self.user_db.get_by_email(account_email))

if not associate_by_email:
raise exceptions.UserAlreadyExists()

# Make sure user is not None before adding OAuth account
if user is not None:
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
else:
# This shouldn't happen since get_by_email would raise UserNotExists
# but adding as a safeguard
raise exceptions.UserNotExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)

# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {

@@ -444,36 +440,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

user = await self.user_db.create(user_dict)

# Add OAuth account only if user creation was successful
if user is not None:
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
raise HTTPException(
status_code=500, detail="Failed to create user account"
)
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema

# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)

else:
# User exists, update OAuth account if needed
if user is not None:  # Add explicit check
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account,  # type: ignore
oauth_account_dict,
)

# Ensure user is not None before proceeding
if user is None:
raise HTTPException(
status_code=500, detail="Failed to authenticate or create user"
)
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account,  # type: ignore
oauth_account_dict,
)

# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled

@@ -568,7 +554,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async_return_default_schema,
)(email=user.email)

send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
send_forgot_password_email(user.email, token, tenant_id=tenant_id)

async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
@@ -2,7 +2,6 @@ import logging
import multiprocessing
import time
from typing import Any
from typing import cast

import sentry_sdk
from celery import Task

@@ -132,9 +131,9 @@ def on_task_postrun(
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = None
else:
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
tenant_id = kwargs.get("tenant_id")

task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
@@ -34,7 +34,7 @@ def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
tenant_id: str | None = None,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.

@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
@@ -57,51 +57,6 @@ class TaskDependencyError(RuntimeError):
with connector deletion."""


def revoke_tasks_blocking_deletion(
redis_connector: RedisConnector, db_session: Session, app: Celery
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")

try:
permissions_sync_payload = redis_connector.permissions.payload
if permissions_sync_payload and permissions_sync_payload.celery_task_id:
app.control.revoke(permissions_sync_payload.celery_task_id)
task_logger.info(
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking pruning task")

try:
prune_payload = redis_connector.prune.payload
if prune_payload and prune_payload.celery_task_id:
app.control.revoke(prune_payload.celery_task_id)
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
except Exception:
task_logger.exception("Exception while revoking permissions sync task")

try:
external_group_sync_payload = redis_connector.external_group_sync.payload
if external_group_sync_payload and external_group_sync_payload.celery_task_id:
app.control.revoke(external_group_sync_payload.celery_task_id)
task_logger.info(
f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking external group sync task")


@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,

@@ -109,7 +64,9 @@ def revoke_tasks_blocking_deletion(
trail=False,
bind=True,
)
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

@@ -119,7 +76,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)

# Prevent this task from overlapping with itself
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None

@@ -156,38 +113,9 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# on the first error, we set a stop signal and revoke the dependent tasks
# on subsequent errors, we hard reset blocking fences after our specified timeout
# is exceeded
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))

if not redis_connector.stop.fenced:
# one time revoke of celery tasks
task_logger.info("Revoking any tasks blocking deletion.")
revoke_tasks_blocking_deletion(
redis_connector, db_session, self.app
)
redis_connector.stop.set_fence(True)
redis_connector.stop.set_timeout()
else:
# stop signal already set
if redis_connector.stop.timed_out:
# waiting too long, just reset blocking fences
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()
else:
# just wait
pass
redis_connector.stop.set_fence(True)
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)

@@ -222,7 +150,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.

@@ -343,7 +271,7 @@ def try_generate_document_cc_pair_cleanup_tasks(


def monitor_connector_deletion_taskset(
tenant_id: str, key_bytes: bytes, r: Redis
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)

@@ -498,7 +426,7 @@ def monitor_connector_deletion_taskset(


def validate_connector_deletion_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,

@@ -538,7 +466,7 @@ def validate_connector_deletion_fences(


def validate_connector_deletion_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
r: Redis,
@@ -221,7 +221,7 @@ def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Returns a randomized payload id on success.
Returns None if no syncing is required."""

@@ -320,7 +320,7 @@ def try_creating_permissions_sync_task(
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair

@@ -410,6 +410,7 @@ def connector_permission_sync_generator_task(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
tenant_id,
enforce_creation=False,
)
if not created:

@@ -509,7 +510,7 @@ def connector_permission_sync_generator_task(
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str,
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,

@@ -584,7 +585,7 @@ def update_external_document_permissions_task(


def validate_permission_sync_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,

@@ -631,7 +632,7 @@ def validate_permission_sync_fences(


def validate_permission_sync_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],

@@ -841,7 +842,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):


def monitor_ccpair_permissions_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -123,7 +123,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client()

@@ -220,7 +220,7 @@ def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""

@@ -306,7 +306,7 @@ def try_creating_external_group_sync_task(
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""
External group sync task for a given connector credential pair

@@ -392,6 +392,7 @@ def connector_external_group_sync_generator_task(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
tenant_id,
enforce_creation=False,
)
if not created:

@@ -493,7 +494,7 @@ def connector_external_group_sync_generator_task(


def validate_external_group_sync_fences(
tenant_id: str,
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_replica: Redis,

@@ -525,7 +526,7 @@ def validate_external_group_sync_fences(


def validate_external_group_sync_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
@@ -182,7 +182,7 @@ class SimpleJobResult:


class ConnectorIndexingContext(BaseModel):
tenant_id: str
tenant_id: str | None
cc_pair_id: int
search_settings_id: int
index_attempt_id: int

@@ -210,7 +210,7 @@ class ConnectorIndexingLogBuilder:


def monitor_ccpair_indexing_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")

@@ -358,7 +358,7 @@ def monitor_ccpair_indexing_taskset(
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""

@@ -598,7 +598,7 @@ def connector_indexing_task(
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing

@@ -890,7 +890,7 @@ def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.

@@ -1170,7 +1170,7 @@ def connector_indexing_proxy_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
)
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
"""Clean up old checkpoints that are older than 7 days."""
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)

@@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase):


def validate_indexing_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,

@@ -311,7 +311,7 @@ def validate_indexing_fence(


def validate_indexing_fences(
tenant_id: str,
tenant_id: str | None,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,

@@ -442,7 +442,7 @@ def try_creating_indexing_task(
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
trail=False,
bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
if not LLM_MODEL_UPDATE_API_URL:
raise ValueError("LLM model update API URL not configured")
@@ -91,7 +91,7 @@ class Metric(BaseModel):
}
task_logger.info(json.dumps(data))

def emit(self, tenant_id: str) -> None:
def emit(self, tenant_id: str | None) -> None:
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None

@@ -656,7 +656,7 @@ def build_job_id(
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
"""Collect and emit metrics about background processes.
This task runs periodically to gather metrics about:
- Queue lengths for different Celery queues

@@ -864,7 +864,7 @@ def cloud_monitor_celery_queues(


@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
return monitor_celery_queues_helper(self)
@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""

# we will select messages older than this amount to clean up
@@ -114,7 +114,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

@@ -211,7 +211,7 @@ def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.

@@ -333,7 +333,7 @@ def connector_pruning_generator_task(
cc_pair_id: int,
connector_id: int,
credential_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing

@@ -521,7 +521,7 @@ def connector_pruning_generator_task(


def monitor_ccpair_pruning_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)

@@ -567,7 +567,7 @@ def monitor_ccpair_pruning_taskset(


def validate_pruning_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,

@@ -615,7 +615,7 @@ def validate_pruning_fences(


def validate_pruning_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],
@@ -32,7 +32,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
) -> int:
return self.index.delete_single(

@@ -50,7 +50,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:

@@ -76,7 +76,7 @@ def document_by_cc_pair_cleanup_task(
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str,
tenant_id: str | None,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
@@ -297,8 +297,7 @@ cloud_beat_task_generator(
return None

last_lock_time = time.monotonic()
tenant_ids: list[str] = []
num_processed_tenants = 0
tenant_ids: list[str] | list[None] = []

try:
tenant_ids = get_all_tenant_ids()

@@ -326,8 +325,6 @@ def cloud_beat_task_generator(
expires=expires,
ignore_result=True,
)

num_processed_tenants += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."

@@ -347,7 +344,6 @@ def cloud_beat_task_generator(
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)
@@ -76,7 +76,7 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""

@@ -208,7 +208,7 @@ def try_generate_stale_document_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing

@@ -284,7 +284,7 @@ def try_generate_document_set_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()

@@ -361,7 +361,7 @@ def try_generate_user_group_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()

@@ -448,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None:


def monitor_document_set_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)

@@ -523,7 +523,9 @@ def monitor_document_set_taskset(
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
def vespa_metadata_sync_task(
self: Task, document_id: str, *, tenant_id: str | None
) -> bool:
start = time.monotonic()

completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
@@ -16,7 +16,7 @@ from typing import Optional

from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.setup import setup_logger
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -55,7 +55,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

@@ -68,6 +67,7 @@ def _get_connector_runner(
batch_size: int,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
) -> ConnectorRunner:
"""

@@ -86,6 +86,7 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)

# validate the connector settings

@@ -240,7 +241,7 @@ def _check_failure_threshold(
def _run_indexing(
db_session: Session,
index_attempt_id: int,
tenant_id: str,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""

@@ -387,6 +388,7 @@ def _run_indexing(
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)

# don't use a checkpoint if we're explicitly indexing from

@@ -679,7 +681,7 @@ def _run_indexing(

def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,

@@ -699,7 +701,7 @@ def run_indexing_entrypoint(
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)

tenant_str = ""
if MULTI_TENANT:
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"

connector_name = attempt.connector_credential_pair.connector.name
@@ -6,7 +6,6 @@ from typing import cast
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
#####
|
||||
@@ -30,9 +29,6 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
|
||||
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
|
||||
)
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
|
||||
@@ -213,12 +213,6 @@ class AuthType(str, Enum):
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class QueryHistoryType(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
ANONYMIZED = "anonymized"
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
# Special characters for password validation
|
||||
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@ from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from requests import HTTPError

from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.utils.logger import setup_logger

@@ -163,7 +161,7 @@ class OnyxConfluence(Confluence):
)

def _paginate_url(
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
self, url_suffix: str, limit: int | None = None
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -238,41 +236,9 @@ class OnyxConfluence(Confluence):
raise e

# yield the results individually
results = cast(list[dict[str, Any]], next_response.get("results", []))
yield from results
yield from next_response.get("results", [])

old_url_suffix = url_suffix
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))

# make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
if url_suffix and "start" in url_suffix:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
)

# Update the url_suffix to use the adjusted start
adjusted_start = previous_start + len(results)
url_suffix = update_param_in_path(
url_suffix, "start", str(adjusted_start)
)

# some APIs don't properly paginate, so we need to manually update the `start` param
if auto_paginate and len(results) > 0:
previous_start = get_start_param_from_url(old_url_suffix)
updated_start = previous_start + len(results)
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
url_suffix = next_response.get("_links", {}).get("next")

def paginated_cql_retrieval(
self,
@@ -332,9 +298,7 @@ class OnyxConfluence(Confluence):
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
# endpoint doesn't properly paginate, so we need to manually update the `start` param
# thus the auto_paginate flag
for user_result in self._paginate_url(url, limit, auto_paginate=True):
for user_result in self._paginate_url(url, limit):
# Example response:
# {
# 'user': {
|
||||
@@ -2,10 +2,7 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
|
||||
@@ -13,13 +10,13 @@ from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
OnyxConfluence,
|
||||
)
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -27,7 +24,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: "OnyxConfluence", user_name: str
|
||||
confluence_client: OnyxConfluence, user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
@@ -50,7 +47,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
@@ -81,7 +78,7 @@ def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
@@ -194,7 +191,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
@@ -282,32 +279,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
def get_single_param_from_url(url: str, param: str) -> str | None:
|
||||
"""Get a parameter from a url"""
|
||||
parsed_url = urlparse(url)
|
||||
return parse_qs(parsed_url.query).get(param, [None])[0]
|
||||
|
||||
|
||||
def get_start_param_from_url(url: str) -> int:
|
||||
"""Get the start parameter from a url"""
|
||||
start_str = get_single_param_from_url(url, "start")
|
||||
if start_str is None:
|
||||
return 0
|
||||
return int(start_str)
|
||||
|
||||
|
||||
def update_param_in_path(path: str, param: str, value: str) -> str:
|
||||
"""Update a parameter in a path. Path should look something like:
|
||||
|
||||
/api/rest/users?start=0&limit=10
|
||||
"""
|
||||
parsed_url = urlparse(path)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params[param] = [value]
|
||||
return (
|
||||
path.split("?")[0]
|
||||
+ "?"
|
||||
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
@@ -163,9 +164,13 @@ def instantiate_connector(
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
tenant_id: str | None = None,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
if source in DocumentSourceRequiringTenantContext:
|
||||
connector_specific_config["tenant_id"] = tenant_id
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
@@ -179,6 +184,7 @@ def validate_ccpair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None,
|
||||
enforce_creation: bool = True,
|
||||
) -> bool:
|
||||
if INTEGRATION_TESTS_MODE:
|
||||
@@ -210,6 +216,7 @@ def validate_ccpair_for_user(
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
credential=credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -27,6 +27,8 @@ from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -163,10 +165,12 @@ class LocalFileConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [Path(file_location) for file_location in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.tenant_id = tenant_id
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -175,8 +179,9 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
@@ -198,6 +203,8 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
|
||||
|
||||
@@ -124,7 +124,7 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
repo_name: str | None = None,
|
||||
repo_name: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
state_filter: str = "all",
|
||||
include_prs: bool = True,
|
||||
@@ -162,81 +162,53 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_github_repo(github_client, attempt_num + 1)
|
||||
|
||||
def _get_all_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
|
||||
try:
|
||||
# Try to get organization first
|
||||
try:
|
||||
org = github_client.get_organization(self.repo_owner)
|
||||
return list(org.get_repos())
|
||||
except GithubException:
|
||||
# If not an org, try as a user
|
||||
user = github_client.get_user(self.repo_owner)
|
||||
return list(user.get_repos())
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _fetch_from_github(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
repos = (
|
||||
[self._get_github_repo(self.github_client)]
|
||||
if self.repo_name
|
||||
else self._get_all_repos(self.github_client)
|
||||
)
|
||||
repo = self._get_github_repo(self.github_client)
|
||||
|
||||
for repo in repos:
|
||||
if self.include_prs:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
if self.include_prs:
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
for pr_batch in _batch_github_objects(
|
||||
pull_requests, self.github_client, self.batch_size
|
||||
):
|
||||
doc_batch: list[Document] = []
|
||||
for pr in pr_batch:
|
||||
if start is not None and pr.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and pr.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield doc_batch
|
||||
for pr_batch in _batch_github_objects(
|
||||
pull_requests, self.github_client, self.batch_size
|
||||
):
|
||||
doc_batch: list[Document] = []
|
||||
for pr in pr_batch:
|
||||
if start is not None and pr.updated_at < start:
|
||||
yield doc_batch
|
||||
return
|
||||
if end is not None and pr.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield doc_batch
|
||||
|
||||
if self.include_issues:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
if self.include_issues:
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
for issue_batch in _batch_github_objects(
|
||||
issues, self.github_client, self.batch_size
|
||||
):
|
||||
doc_batch = []
|
||||
for issue in issue_batch:
|
||||
issue = cast(Issue, issue)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
for issue_batch in _batch_github_objects(
|
||||
issues, self.github_client, self.batch_size
|
||||
):
|
||||
doc_batch = []
|
||||
for issue in issue_batch:
|
||||
issue = cast(Issue, issue)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
return
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_github()
|
||||
@@ -262,26 +234,16 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")

if not self.repo_owner:
if not self.repo_owner or not self.repo_name:
raise ConnectorValidationError(
"Invalid connector settings: 'repo_owner' must be provided."
"Invalid connector settings: 'repo_owner' and 'repo_name' must be provided."
)

try:
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
org = self.github_client.get_organization(self.repo_owner)
org.get_repos().totalCount # Just check if we can access repos
except GithubException:
# If not an org, try as a user
user = self.github_client.get_user(self.repo_owner)
user.get_repos().totalCount # Just check if we can access repos
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")

except RateLimitExceededException:
raise UnexpectedError(
@@ -298,14 +260,9 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
)
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
|
||||
@@ -1,9 +1,7 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import openpyxl # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
@@ -45,15 +43,12 @@ def _extract_sections_basic(
|
||||
) -> list[Section]:
|
||||
mime_type = file["mimeType"]
|
||||
link = file["webViewLink"]
|
||||
supported_file_types = set(item.value for item in GDriveMimeType)
|
||||
|
||||
if mime_type not in supported_file_types:
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
try:
|
||||
# ---------------------------
|
||||
# Google Sheets extraction
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
try:
|
||||
sheets_service = build(
|
||||
@@ -114,53 +109,7 @@ def _extract_sections_basic(
|
||||
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
)
|
||||
# ---------------------------
|
||||
# Microsoft Excel (.xlsx or .xls) extraction branch
|
||||
elif mime_type in [
|
||||
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
|
||||
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
|
||||
]:
|
||||
try:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
|
||||
tmp.write(response)
|
||||
tmp_path = tmp.name
|
||||
|
||||
section_separator = "\n\n"
|
||||
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
|
||||
|
||||
# Work similarly to the xlsx_to_text function used for file connector
|
||||
# but returns Sections instead of a string
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
text=(
|
||||
f"Sheet: {sheet.title}\n\n"
|
||||
+ section_separator.join(
|
||||
",".join(map(str, row))
|
||||
for row in sheet.iter_rows(
|
||||
min_row=1, values_only=True
|
||||
)
|
||||
if row
|
||||
)
|
||||
),
|
||||
)
|
||||
for sheet in workbook.worksheets
|
||||
]
|
||||
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error extracting data from Excel file '{file['name']}': {e}"
|
||||
)
|
||||
return [
|
||||
Section(link=link, text="Error extracting data from Excel file")
|
||||
]
|
||||
|
||||
# ---------------------------
|
||||
# Export for Google Docs, PPT, and fallback for spreadsheets
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
@@ -179,8 +128,6 @@ def _extract_sections_basic(
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
# ---------------------------
|
||||
# Plain text and Markdown files
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
@@ -194,8 +141,6 @@ def _extract_sections_basic(
|
||||
.decode("utf-8"),
|
||||
)
|
||||
]
|
||||
# ---------------------------
|
||||
# Word, PowerPoint, PDF files
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
@@ -225,11 +170,7 @@ def _extract_sections_basic(
|
||||
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
|
||||
# Catch-all case, should not happen since there should be specific handling
|
||||
# for each of the supported file types
|
||||
error_message = f"Unsupported file type: {mime_type}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
except Exception:
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
@@ -5,10 +5,6 @@ from typing import Any
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
SPREADSHEET_OPEN_FORMAT = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
|
||||
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
|
||||
from onyx.connectors.onyx_jira.utils import build_jira_client
|
||||
from onyx.connectors.onyx_jira.utils import build_jira_url
|
||||
from onyx.connectors.onyx_jira.utils import extract_jira_project
|
||||
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
|
||||
from onyx.connectors.onyx_jira.utils import get_comment_strs
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -159,8 +160,7 @@ def fetch_jira_issues_batch(
|
||||
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
project_key: str | None = None,
|
||||
jira_project_url: str,
|
||||
comment_email_blacklist: list[str] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
# if a ticket has one of the labels specified in this list, we will just
|
||||
@@ -169,12 +169,11 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
|
||||
self.jira_project = project_key
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
|
||||
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
|
||||
self._jira_client: JIRA | None = None
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
|
||||
@property
|
||||
def comment_email_blacklist(self) -> tuple:
|
||||
@@ -189,9 +188,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
@property
|
||||
def quoted_jira_project(self) -> str:
|
||||
# Quote the project name to handle reserved words
|
||||
if not self.jira_project:
|
||||
return ""
|
||||
return f'"{self.jira_project}"'
|
||||
return f'"{self._jira_project}"'
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._jira_client = build_jira_client(
|
||||
@@ -200,14 +197,8 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_jql_query(self) -> str:
|
||||
"""Get the JQL query based on whether a specific project is set"""
|
||||
if self.jira_project:
|
||||
return f"project = {self.quoted_jira_project}"
|
||||
return "" # Empty string means all accessible projects
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
jql = self._get_jql_query()
|
||||
jql = f"project = {self.quoted_jira_project}"
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
@@ -234,10 +225,11 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
base_jql = self._get_jql_query()
|
||||
jql = (
|
||||
f"{base_jql} AND " if base_jql else ""
|
||||
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
|
||||
f"project = {self.quoted_jira_project} AND "
|
||||
f"updated >= '{start_date_str}' AND "
|
||||
f"updated <= '{end_date_str}'"
|
||||
)
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
@@ -260,7 +252,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
jql = self._get_jql_query()
|
||||
jql = f"project = {self.quoted_jira_project}"
|
||||
|
||||
slim_doc_batch = []
|
||||
for issue in _paginate_jql_search(
|
||||
@@ -287,63 +279,43 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if self._jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
# If a specific project is set, validate it exists
|
||||
if self.jira_project:
|
||||
try:
|
||||
self.jira_client.project(self.jira_project)
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", None)
|
||||
if not self._jira_project:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid connector settings: 'jira_project' must be provided."
|
||||
)
|
||||
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Jira credential appears to be expired or invalid (HTTP 401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
|
||||
)
|
||||
elif status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Jira project not found with key: {self.jira_project}"
|
||||
)
|
||||
elif status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
|
||||
)
|
||||
try:
|
||||
self.jira_client.project(self._jira_project)
|
||||
|
||||
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
|
||||
else:
|
||||
# If no project specified, validate we can access the Jira API
|
||||
try:
|
||||
# Try to list projects to validate access
|
||||
self.jira_client.projects()
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", None)
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Jira credential appears to be expired or invalid (HTTP 401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Jira token does not have sufficient permissions to list projects (HTTP 403)."
|
||||
)
|
||||
elif status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
|
||||
)
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", None)
|
||||
|
||||
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Jira credential appears to be expired or invalid (HTTP 401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
|
||||
)
|
||||
elif status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Jira project not found with key: {self._jira_project}"
|
||||
)
|
||||
elif status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Unexpected Jira error during validation: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = JiraConnector(
|
||||
jira_base_url=os.environ["JIRA_BASE_URL"],
|
||||
project_key=os.environ.get("JIRA_PROJECT_KEY"),
|
||||
comment_email_blacklist=[],
|
||||
os.environ["JIRA_PROJECT_URL"], comment_email_blacklist=[]
|
||||
)
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
"jira_user_email": os.environ["JIRA_USER_EMAIL"],
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
@@ -72,7 +73,7 @@ def insert_api_key(
|
||||
# Get tenant_id from context var (will be default schema for single tenant)
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
api_key = generate_api_key(tenant_id)
|
||||
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
|
||||
api_key_user_id = uuid.uuid4()
|
||||
|
||||
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
|
||||
@@ -168,7 +168,7 @@ def get_chat_sessions_by_user(
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_updated))
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
@@ -962,7 +962,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
chat_message.sub_questions
|
||||
),
|
||||
refined_answer_improvement=chat_message.refined_answer_improvement,
|
||||
is_agentic=chat_message.is_agentic,
|
||||
error=chat_message.error,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def search_chat_sessions(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
query: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
include_deleted: bool = False,
|
||||
include_onyxbot_flows: bool = False,
|
||||
) -> Tuple[List[ChatSession], bool]:
|
||||
"""
|
||||
Fast full-text search on ChatSession + ChatMessage using tsvectors.
|
||||
|
||||
If no query is provided, returns the most recent chat sessions.
|
||||
Otherwise, searches both chat messages and session descriptions.
|
||||
|
||||
Returns a tuple of (sessions, has_more) where has_more indicates if
|
||||
there are additional results beyond the requested page.
|
||||
"""
|
||||
offset_val = (page - 1) * page_size
|
||||
|
||||
# If no query, just return the most recent sessions
|
||||
if not query or not query.strip():
|
||||
stmt = (
|
||||
select(ChatSession)
|
||||
.order_by(desc(ChatSession.time_created))
|
||||
.offset(offset_val)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
sessions = result.scalars().all()
|
||||
|
||||
has_more = len(sessions) > page_size
|
||||
if has_more:
|
||||
sessions = sessions[:page_size]
|
||||
|
||||
return list(sessions), has_more
|
||||
|
||||
# Otherwise, proceed with full-text search
|
||||
query = query.strip()
|
||||
|
||||
base_conditions = []
|
||||
if user_id is not None:
|
||||
base_conditions.append(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
base_conditions.append(ChatSession.deleted.is_(False))
|
||||
|
||||
message_tsv: ColumnClause = column("message_tsv")
|
||||
description_tsv: ColumnClause = column("description_tsv")
|
||||
|
||||
ts_query = func.plainto_tsquery("english", query)
|
||||
|
||||
description_session_ids = (
|
||||
select(ChatSession.id)
|
||||
.where(*base_conditions)
|
||||
.where(description_tsv.op("@@")(ts_query))
|
||||
)
|
||||
|
||||
message_session_ids = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(*base_conditions)
|
||||
.where(message_tsv.op("@@")(ts_query))
|
||||
)
|
||||
|
||||
combined_ids = description_session_ids.union(message_session_ids).alias(
|
||||
"combined_ids"
|
||||
)
|
||||
|
||||
final_stmt = (
|
||||
select(ChatSession)
|
||||
.join(combined_ids, ChatSession.id == combined_ids.c.id)
|
||||
.order_by(desc(ChatSession.time_created))
|
||||
.distinct()
|
||||
.offset(offset_val)
|
||||
.limit(page_size + 1)
|
||||
.options(joinedload(ChatSession.persona))
|
||||
)
|
||||
|
||||
session_objs = db_session.execute(final_stmt).scalars().all()
|
||||
|
||||
has_more = len(session_objs) > page_size
|
||||
if has_more:
|
||||
session_objs = session_objs[:page_size]
|
||||
|
||||
return list(session_objs), has_more
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
@@ -9,18 +8,15 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
@@ -35,12 +31,10 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
R = TypeVarTuple("R")
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
|
||||
) -> Select[tuple[*R]]:
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
return stmt
|
||||
@@ -104,52 +98,17 @@ def get_connector_credential_pairs_for_user(
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
if eager_load_user:
|
||||
assert (
|
||||
eager_load_credential
|
||||
), "eager_load_credential must be True if eager_load_user is True"
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
|
||||
|
||||
if eager_load_credential:
|
||||
load_opts = selectinload(ConnectorCredentialPair.credential)
|
||||
if eager_load_user:
|
||||
load_opts = load_opts.joinedload(Credential.user)
|
||||
stmt = stmt.options(load_opts)
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
return list(db_session.scalars(stmt).unique().all())
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_connector_credential_pairs_for_user_parallel(
|
||||
user: User | None,
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
db_session,
|
||||
user,
|
||||
get_editable,
|
||||
ids,
|
||||
eager_load_connector,
|
||||
eager_load_credential,
|
||||
eager_load_user,
|
||||
)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_connector_credential_pairs(
|
||||
@@ -192,16 +151,6 @@ def get_cc_pair_groups_for_ids(
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_cc_pair_groups_for_ids_parallel(
|
||||
cc_pair_ids: list[int],
|
||||
) -> list[UserGroup__ConnectorCredentialPair]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
|
||||
|
||||
|
||||
def get_connector_credential_pair_for_user(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
|
||||
@@ -24,7 +24,6 @@ from sqlalchemy.sql.expression import null
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
@@ -230,12 +229,12 @@ def get_document_connector_counts(
|
||||
|
||||
|
||||
def get_document_counts_for_cc_pairs(
|
||||
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
|
||||
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
|
||||
|
||||
# Prepare a list of (connector_id, credential_id) tuples
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
@@ -261,16 +260,6 @@ def get_document_counts_for_cc_pairs(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_document_counts_for_cc_pairs_parallel(
|
||||
cc_pairs: list[ConnectorCredentialPairIdentifier],
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
|
||||
|
||||
|
||||
def get_access_info_for_document(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
|
||||
@@ -218,7 +218,6 @@ class SqlEngine:
|
||||
final_engine_kwargs.update(engine_kwargs)
|
||||
|
||||
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
|
||||
# echo=True here for inspecting all emitted db queries
|
||||
engine = create_engine(connection_string, **final_engine_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
@@ -258,11 +257,11 @@
cls._engine = None


def get_all_tenant_ids() -> list[str]:
def get_all_tenant_ids() -> list[str] | list[None]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""

if not MULTI_TENANT:
return [POSTGRES_DEFAULT_SCHEMA]
return [None]

with get_session_with_shared_schema() as session:
result = session.execute(
@@ -417,7 +416,7 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:


@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
"""
|
||||
@@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
@@ -10,13 +9,9 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
@@ -373,33 +368,19 @@ def get_latest_index_attempts_by_status(
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
T = TypeVarTuple("T")
|
||||
|
||||
|
||||
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
|
||||
return stmt.where(
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_latest_index_attempts(
|
||||
secondary_index: bool,
|
||||
db_session: Session,
|
||||
eager_load_cc_pair: bool = False,
|
||||
only_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
ids_stmt = select(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.id).label("max_id"),
|
||||
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == status)
|
||||
|
||||
if only_finished:
|
||||
ids_stmt = _add_only_finished_clause(ids_stmt)
|
||||
if secondary_index:
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
|
||||
else:
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
|
||||
|
||||
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
ids_subquery = ids_stmt.subquery()
|
||||
@@ -414,53 +395,7 @@ def get_latest_index_attempts(
|
||||
.where(IndexAttempt.id == ids_subquery.c.max_id)
|
||||
)
|
||||
|
||||
if only_finished:
|
||||
stmt = _add_only_finished_clause(stmt)
|
||||
|
||||
if eager_load_cc_pair:
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector_credential_pair),
|
||||
joinedload(IndexAttempt.error_rows),
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).scalars().unique().all()
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_latest_index_attempts_parallel(
|
||||
secondary_index: bool,
|
||||
eager_load_cc_pair: bool = False,
|
||||
only_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_latest_index_attempts(
|
||||
secondary_index,
|
||||
db_session,
|
||||
eager_load_cc_pair,
|
||||
only_finished,
|
||||
)
|
||||
|
||||
|
||||
def get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool,
|
||||
only_finished: bool = True,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
)
|
||||
if only_finished:
|
||||
stmt = _add_only_finished_clause(stmt)
|
||||
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
stmt = stmt.limit(1)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def count_index_attempts_for_connector(
|
||||
@@ -518,12 +453,37 @@ def get_paginated_index_attempts_for_cc_pair_id(
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
stmt = stmt.options(
|
||||
contains_eager(IndexAttempt.connector_credential_pair),
|
||||
joinedload(IndexAttempt.error_rows),
|
||||
)
|
||||
|
||||
return list(db_session.execute(stmt).scalars().unique().all())
|
||||
return list(db_session.execute(stmt).scalars().all())
|
||||
|
||||
|
||||
def get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool,
|
||||
only_finished: bool = True,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
)
|
||||
if only_finished:
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
if secondary_index:
|
||||
stmt = stmt.join(SearchSettings).where(
|
||||
SearchSettings.status == IndexModelStatus.FUTURE
|
||||
)
|
||||
else:
|
||||
stmt = stmt.join(SearchSettings).where(
|
||||
SearchSettings.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
stmt = stmt.limit(1)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_index_attempts_for_cc_pair(
|
||||
|
||||
@@ -25,7 +25,6 @@ from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
|
||||
from sqlalchemy import Sequence
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
|
||||
@@ -100,14 +100,9 @@ def _add_user_filters(
|
||||
.correlate(Persona)
|
||||
)
|
||||
else:
|
||||
# Group the public persona conditions
|
||||
public_condition = (Persona.is_public == True) & ( # noqa: E712
|
||||
Persona.is_visible == True # noqa: E712
|
||||
)
|
||||
|
||||
where_clause |= public_condition
|
||||
where_clause |= Persona.is_public == True # noqa: E712
|
||||
where_clause &= Persona.is_visible == True # noqa: E712
|
||||
where_clause |= Persona__User.user_id == user.id
|
||||
|
||||
where_clause |= Persona.user_id == user.id
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
# Vespa's Document API.
|
||||
def get_document_chunk_ids(
|
||||
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
large_chunks_enabled: bool,
|
||||
) -> list[UUID]:
|
||||
doc_chunk_ids = []
|
||||
@@ -139,7 +139,7 @@ def get_uuid_from_chunk_info(
|
||||
*,
|
||||
document_id: str,
|
||||
chunk_id: int,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
large_chunk_id: int | None = None,
|
||||
) -> UUID:
|
||||
"""NOTE: be VERY carefuly about changing this function. If changed without a migration,
|
||||
@@ -154,7 +154,7 @@ def get_uuid_from_chunk_info(
|
||||
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
|
||||
)
|
||||
unique_identifier_string = "_".join([doc_str, chunk_index])
|
||||
if MULTI_TENANT:
|
||||
if tenant_id and MULTI_TENANT:
|
||||
unique_identifier_string += "_" + tenant_id
|
||||
|
||||
uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
@@ -43,7 +43,7 @@ class IndexBatchParams:
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
tenant_id: str
|
||||
tenant_id: str | None
|
||||
large_chunks_enabled: bool
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class Deletable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ class Updatable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
@@ -270,7 +270,9 @@ class Updatable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Updates some set of chunks. The document and fields to update are specified in the update
|
||||
requests. Each update request in the list applies its changes to a list of document ids.
|
||||
|
||||
@@ -468,7 +468,9 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
@@ -616,7 +618,7 @@ class VespaIndex(DocumentIndex):
|
||||
doc_id: str,
|
||||
*,
|
||||
chunk_count: int | None,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
@@ -659,7 +661,7 @@ class VespaIndex(DocumentIndex):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
total_chunks_deleted = 0
|
||||
|
||||
@@ -158,8 +158,8 @@ def index_doc_batch_with_handler(
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
@@ -317,8 +317,8 @@ def index_doc_batch(
|
||||
document_index: DocumentIndex,
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
|
||||
) -> IndexingPipelineResult:
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
@@ -525,9 +525,9 @@ def build_indexing_pipeline(
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
|
||||
@@ -84,7 +84,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
negative -> ranked lower.
"""

tenant_id: str
tenant_id: str | None = None
access: "DocumentAccess"
document_sets: set[str]
boost: int
@@ -96,7 +96,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
tenant_id: str,
tenant_id: str | None,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
return cls(
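DocMetadataAwareIndexChunk keeps its construction path; only the tenant_id field and the matching classmethod argument become optional. A compact pydantic sketch of that shape (access, the embeddings and the other real fields are omitted here, so this is an illustration rather than the actual model):

from pydantic import BaseModel


class IndexChunk(BaseModel):
    document_id: str
    content: str


class DocMetadataAwareIndexChunk(IndexChunk):
    # None now means "no tenant schema", e.g. a single-tenant deployment
    tenant_id: str | None = None
    document_sets: set[str]
    boost: int

    @classmethod
    def from_index_chunk(
        cls,
        index_chunk: IndexChunk,
        document_sets: set[str],
        boost: int,
        tenant_id: str | None,
    ) -> "DocMetadataAwareIndexChunk":
        index_chunk_data = index_chunk.model_dump()
        return cls(
            **index_chunk_data,
            document_sets=document_sets,
            boost=boost,
            tenant_id=tenant_id,
        )


chunk = DocMetadataAwareIndexChunk.from_index_chunk(
    IndexChunk(document_id="doc-1", content="hello"),
    document_sets={"default"},
    boost=0,
    tenant_id=None,
)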
@@ -103,7 +103,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
api_version_required=False,
custom_config_keys=[],
llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
default_model="claude-3-7-sonnet-20250219",
default_model="claude-3-5-sonnet-20241022",
default_fast_model="claude-3-5-sonnet-20241022",
),
WellKnownLLMProviderDescriptor(

@@ -219,7 +219,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
setup_onyx(db_session, None)
else:
setup_multitenant_onyx()
@@ -23,7 +23,7 @@ from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_session_by_message_id
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChannelConfig
from onyx.onyxbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
@@ -410,11 +410,12 @@ def _build_qa_response_blocks(


def _build_continue_in_web_ui_block(
tenant_id: str | None,
message_id: int | None,
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,
@@ -481,6 +482,7 @@ def build_follow_up_resolved_blocks(

def build_slack_response_blocks(
answer: ChatOnyxBotResponse,
tenant_id: str | None,
message_info: SlackMessageInfo,
channel_conf: ChannelConfig | None,
use_citations: bool,
@@ -515,6 +517,7 @@ def build_slack_response_blocks(
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
web_follow_up_block.append(
_build_continue_in_web_ui_block(
tenant_id=tenant_id,
message_id=answer.chat_message_id,
)
)
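The pattern running through these Slack block builders is to thread tenant_id down explicitly and open the database session against it, instead of reading an ambient current tenant. A minimal sketch of the two helper styles, with a plain dict standing in for the real SQLAlchemy session:

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator

CURRENT_TENANT_ID: ContextVar[str | None] = ContextVar("tenant_id", default=None)


@contextmanager
def get_session_with_tenant(tenant_id: str | None) -> Iterator[dict]:
    # Explicit style: the caller states which tenant the session is for.
    yield {"schema": tenant_id or "public"}


@contextmanager
def get_session_with_current_tenant() -> Iterator[dict]:
    # Implicit style: whatever the context variable happens to hold.
    yield {"schema": CURRENT_TENANT_ID.get() or "public"}


def build_continue_in_web_ui_block(tenant_id: str | None, message_id: int) -> dict:
    # Explicit variant, mirroring _build_continue_in_web_ui_block above.
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        return {"schema": db_session["schema"], "message_id": message_id}


print(build_continue_in_web_ui_block(tenant_id="tenant_a", message_id=42))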
||||
@@ -11,7 +11,7 @@ from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import create_doc_retrieval_feedback
|
||||
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
|
||||
@@ -114,7 +114,7 @@ def handle_generate_answer_button(
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -136,6 +136,7 @@ def handle_generate_answer_button(
|
||||
slack_channel_config=slack_channel_config,
|
||||
receiver_ids=None,
|
||||
client=client.web_client,
|
||||
tenant_id=client.tenant_id,
|
||||
channel=channel_id,
|
||||
logger=logger,
|
||||
feedback_reminder_id=None,
|
||||
@@ -150,10 +151,11 @@ def handle_slack_feedback(
|
||||
user_id_to_post_confirmation: str,
|
||||
channel_id_to_post_confirmation: str,
|
||||
thread_ts_to_post_confirmation: str,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
|
||||
create_chat_message_feedback(
|
||||
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
|
||||
@@ -244,7 +246,7 @@ def handle_followup_button(
|
||||
|
||||
tag_ids: list[str] = []
|
||||
group_ids: list[str] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
channel_name, is_dm = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel_id
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.onyxbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -109,6 +109,7 @@ def handle_message(
|
||||
slack_channel_config: SlackChannelConfig,
|
||||
client: WebClient,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
) -> bool:
|
||||
"""Potentially respond to the user message depending on filters and if an answer was generated
|
||||
|
||||
@@ -134,7 +135,9 @@ def handle_message(
|
||||
action = "slack_tag_message"
|
||||
elif is_bot_dm:
|
||||
action = "slack_dm_message"
|
||||
slack_usage_report(action=action, sender_id=sender_id, client=client)
|
||||
slack_usage_report(
|
||||
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
@@ -215,7 +218,7 @@ def handle_message(
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if message_info.email:
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
@@ -241,5 +244,6 @@ def handle_message(
|
||||
channel=channel,
|
||||
logger=logger,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
return issue_with_regular_answer
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -71,6 +72,7 @@ def handle_regular_answer(
|
||||
channel: str,
|
||||
logger: OnyxLoggingAdapter,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
@@ -85,7 +87,7 @@ def handle_regular_answer(
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.email:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
@@ -94,7 +96,7 @@ def handle_regular_answer(
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona
|
||||
if not persona:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
@@ -155,7 +157,7 @@ def handle_regular_answer(
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, onyx_user: User | None
|
||||
) -> ChatOnyxBotResponse:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=onyx_user,
|
||||
@@ -195,7 +197,7 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=user_message.message,
|
||||
user=user,
|
||||
@@ -359,6 +361,7 @@ def handle_regular_answer(
|
||||
return True
|
||||
|
||||
all_blocks = build_slack_response_blocks(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
channel_conf=channel_conf,
|
||||
|
||||
@@ -17,12 +17,10 @@ from prometheus_client import Gauge
|
||||
from prometheus_client import start_http_server
|
||||
from redis.lock import Lock
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import POD_NAME
|
||||
@@ -37,7 +35,6 @@ from onyx.context.search.retrieval.search_runner import (
|
||||
download_nltk_data,
|
||||
)
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackBot
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -93,7 +90,6 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -125,13 +121,13 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
class SlackbotHandler:
|
||||
def __init__(self) -> None:
|
||||
logger.info("Initializing SlackbotHandler")
|
||||
self.tenant_ids: Set[str] = set()
|
||||
self.tenant_ids: Set[str | None] = set()
|
||||
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
|
||||
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
|
||||
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
|
||||
|
||||
# Store Redis lock objects here so we can release them properly
|
||||
self.redis_locks: Dict[str, Lock] = {}
|
||||
self.redis_locks: Dict[str | None, Lock] = {}
|
||||
|
||||
self.running = True
|
||||
self.pod_id = self.get_pod_id()
|
||||
@@ -195,7 +191,7 @@ class SlackbotHandler:
|
||||
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
|
||||
|
||||
def _manage_clients_per_tenant(
|
||||
self, db_session: Session, tenant_id: str, bot: SlackBot
|
||||
self, db_session: Session, tenant_id: str | None, bot: SlackBot
|
||||
) -> None:
|
||||
"""
|
||||
- If the tokens are missing or empty, close the socket client and remove them.
|
||||
@@ -253,12 +249,7 @@ class SlackbotHandler:
|
||||
- If yes, store them in self.tenant_ids and manage the socket connections.
|
||||
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
|
||||
"""
|
||||
|
||||
all_tenants = [
|
||||
tenant_id
|
||||
for tenant_id in get_all_tenant_ids()
|
||||
if tenant_id not in get_gated_tenants()
|
||||
]
|
||||
all_tenants = get_all_tenant_ids()
|
||||
|
||||
token: Token[str | None]
|
||||
|
||||
@@ -349,7 +340,7 @@ class SlackbotHandler:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Attempt to fetch Slack bots
|
||||
try:
|
||||
bots = list(fetch_slack_bots(db_session=db_session))
|
||||
@@ -387,7 +378,7 @@ class SlackbotHandler:
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def _remove_tenant(self, tenant_id: str) -> None:
|
||||
def _remove_tenant(self, tenant_id: str | None) -> None:
|
||||
"""
|
||||
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
|
||||
(Lock release now happens in `acquire_tenants()`, not here.)
|
||||
@@ -417,7 +408,7 @@ class SlackbotHandler:
|
||||
)
|
||||
|
||||
def start_socket_client(
|
||||
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
|
||||
self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
|
||||
) -> None:
|
||||
socket_client: TenantSocketModeClient = _get_socket_client(
|
||||
slack_bot_tokens, tenant_id, slack_bot_id
|
||||
@@ -425,7 +416,6 @@ class SlackbotHandler:
|
||||
|
||||
try:
|
||||
bot_info = socket_client.web_client.auth_test()
|
||||
|
||||
if bot_info["ok"]:
|
||||
bot_user_id = bot_info["user_id"]
|
||||
user_info = socket_client.web_client.users_info(user=bot_user_id)
|
||||
@@ -436,23 +426,9 @@ class SlackbotHandler:
|
||||
logger.info(
|
||||
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
|
||||
)
|
||||
except SlackApiError as e:
|
||||
# Only error out if we get a not_authed error
|
||||
if "not_authed" in str(e):
|
||||
self.tenant_ids.add(tenant_id)
|
||||
logger.error(
|
||||
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
|
||||
"Error: {e}"
|
||||
)
|
||||
return
|
||||
# Log other Slack API errors but continue
|
||||
logger.error(
|
||||
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log other exceptions but continue
|
||||
logger.error(
|
||||
f"Error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
logger.warning(
|
||||
f"Could not fetch bot name: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
)
|
||||
|
||||
# Append the event handler
|
||||
@@ -588,7 +564,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -682,6 +658,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
|
||||
user_id_to_post_confirmation=user_id,
|
||||
channel_id_to_post_confirmation=channel_id,
|
||||
thread_ts_to_post_confirmation=thread_ts,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
|
||||
query_event_id, _, _ = decompose_action_id(feedback_id)
|
||||
@@ -797,9 +774,8 @@ def process_message(
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
tenant_id = get_current_tenant_id()
logger.debug(
f"Received Slack request of type: '{req.type}' for tenant, {tenant_id}"
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
)

# Throw out requests that can't or shouldn't be handled
@@ -812,39 +788,50 @@ def process_message(
client=client.web_client, channel_id=channel
)

with get_session_with_current_tenant() as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
token: Token[str | None] | None = None
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
try:
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)

follow_up = bool(
slack_channel_config.channel_config
and slack_channel_config.channel_config.get("follow_up_tags") is not None
)
follow_up = bool(
slack_channel_config.channel_config
and slack_channel_config.channel_config.get("follow_up_tags")
is not None
)

feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)

failed = handle_message(
message_info=details,
slack_channel_config=slack_channel_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
)
failed = handle_message(
message_info=details,
slack_channel_config=slack_channel_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)

if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender_id,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender_id,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
finally:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
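The reworked process_message body above follows the usual contextvars discipline: set() returns a Token, the work runs inside try, and finally calls reset() so one tenant's id never leaks into the next request handled on the same thread. Stripped down to just that scaffolding:

from contextvars import ContextVar, Token

CURRENT_TENANT_ID_CONTEXTVAR: ContextVar[str | None] = ContextVar(
    "current_tenant_id", default=None
)


def process_message(tenant_id: str | None, payload: str) -> None:
    token: Token[str | None] | None = None
    if tenant_id:
        # Make the tenant visible to everything called from this thread.
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        print(f"handling {payload!r} for tenant {CURRENT_TENANT_ID_CONTEXTVAR.get()}")
    finally:
        if token:
            # Always restore the previous value, even if handling raised.
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


process_message("tenant_a", "hello")
process_message(None, "single-tenant message")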
||||
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
@@ -903,7 +890,7 @@ def create_process_slack_event() -> (
|
||||
|
||||
|
||||
def _get_socket_client(
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int
|
||||
) -> TenantSocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.onyx.app/slack_bot_setup
|
||||
|
||||
@@ -4,8 +4,6 @@ import re
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -32,7 +30,7 @@ from onyx.configs.onyxbot_configs import (
|
||||
)
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_default_llms
|
||||
@@ -45,7 +43,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.text_processing import replace_whitespaces_w_space
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -572,7 +569,9 @@ def read_slack_thread(
|
||||
return thread_messages
|
||||
|
||||
|
||||
def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None:
|
||||
def slack_usage_report(
|
||||
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
|
||||
) -> None:
|
||||
if DISABLE_TELEMETRY:
|
||||
return
|
||||
|
||||
@@ -584,13 +583,14 @@ def slack_usage_report(action: str, sender_id: str | None, client: WebClient) ->
|
||||
logger.warning("Unable to find sender email")
|
||||
|
||||
if sender_email is not None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
onyx_user = get_user_by_email(email=sender_email, db_session=db_session)
|
||||
|
||||
optional_telemetry(
|
||||
record_type=RecordType.USAGE,
|
||||
data={"action": action},
|
||||
user_id=str(onyx_user.id) if onyx_user else "Non-Onyx-Or-No-Auth-User",
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -663,30 +663,9 @@ def get_feedback_visibility() -> FeedbackVisibility:


class TenantSocketModeClient(SocketModeClient):
def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any):
def __init__(
self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any
):
super().__init__(*args, **kwargs)
self._tenant_id = tenant_id
self.tenant_id = tenant_id
self.slack_bot_id = slack_bot_id

@contextmanager
def _set_tenant_context(self) -> Generator[None, None, None]:
token = None
try:
if self._tenant_id:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self._tenant_id)
yield
finally:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

def enqueue_message(self, message: str) -> None:
with self._set_tenant_context():
super().enqueue_message(message)

def process_message(self) -> None:
with self._set_tenant_context():
super().process_message()

def run_message_listeners(self, message: dict, raw_message: str) -> None:
with self._set_tenant_context():
super().run_message_listeners(message, raw_message)
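For contrast, the code removed here had wrapped every inherited message hook in a small context manager so the tenant contextvar was scoped to each call. A generic version of that wrapping pattern, with a plain base class standing in for slack_sdk's SocketModeClient:

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Iterator

CURRENT_TENANT_ID: ContextVar[str | None] = ContextVar("tenant_id", default=None)


class BaseClient:
    def process_message(self, message: str) -> None:
        print(f"tenant={CURRENT_TENANT_ID.get()} message={message!r}")


class TenantScopedClient(BaseClient):
    def __init__(self, tenant_id: str | None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._tenant_id = tenant_id

    @contextmanager
    def _set_tenant_context(self) -> Iterator[None]:
        token = None
        try:
            if self._tenant_id:
                token = CURRENT_TENANT_ID.set(self._tenant_id)
            yield
        finally:
            if token:
                CURRENT_TENANT_ID.reset(token)

    def process_message(self, message: str) -> None:
        # Every overridden hook runs with the tenant pinned for its duration.
        with self._set_tenant_context():
            super().process_message(message)


TenantScopedClient("tenant_a").process_message("ping")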
||||
@@ -16,10 +16,10 @@ class RedisConnector:
"""Composes several classes to simplify interacting with a connector and its
associated background tasks / associated redis interactions."""

def __init__(self, tenant_id: str, id: int) -> None:
def __init__(self, tenant_id: str | None, id: int) -> None:
"""id: a connector credential pair id"""

self.tenant_id: str = tenant_id
self.tenant_id: str | None = tenant_id
self.id: int = id
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)


@@ -31,7 +31,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
PREFIX = "connectorsync"
TASKSET_PREFIX = PREFIX + "_taskset"

def __init__(self, tenant_id: str, id: int) -> None:
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))

# documents that should be skipped
@@ -60,7 +60,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
db_session: Session,
redis_client: Redis,
lock: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> tuple[int, int] | None:
"""We can limit the number of tasks generated here, which is useful to prevent
one tenant from overwhelming the sync queue.

@@ -39,8 +39,8 @@ class RedisConnectorDelete:
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600

def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
self.tenant_id: str = tenant_id
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis


@@ -52,8 +52,8 @@ class RedisConnectorPermissionSync:
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2

def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
self.tenant_id: str = tenant_id
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis


@@ -44,8 +44,8 @@ class RedisConnectorExternalGroupSync:
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600

def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
self.tenant_id: str = tenant_id
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
self.redis = redis


@@ -52,12 +52,12 @@ class RedisConnectorIndex:

def __init__(
self,
tenant_id: str,
tenant_id: str | None,
id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str = tenant_id
self.tenant_id: str | None = tenant_id
self.id = id
self.search_settings_id = search_settings_id
self.redis = redis
@@ -93,7 +93,10 @@ class RedisConnectorIndex:

@property
def fenced(self) -> bool:
return bool(self.redis.exists(self.fence_key))
if self.redis.exists(self.fence_key):
return True

return False

@property
def payload(self) -> RedisConnectorIndexPayload | None:
@@ -103,7 +106,9 @@ class RedisConnectorIndex:
return None

fence_str = fence_bytes.decode("utf-8")
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))

return payload

def set_fence(
self,
@@ -118,7 +123,10 @@ class RedisConnectorIndex:
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

def terminating(self, celery_task_id: str) -> bool:
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True

return False

def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
@@ -138,7 +146,10 @@ class RedisConnectorIndex:

def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
return bool(self.redis.exists(self.watchdog_key))
if self.redis.exists(self.watchdog_key):
return True

return False

def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -149,7 +160,10 @@ class RedisConnectorIndex:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

def active(self) -> bool:
return bool(self.redis.exists(self.active_key))
if self.redis.exists(self.active_key):
return True

return False

def set_connector_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -166,7 +180,10 @@ class RedisConnectorIndex:
return False

def generator_locked(self) -> bool:
return bool(self.redis.exists(self.generator_lock_key))
if self.redis.exists(self.generator_lock_key):
return True

return False

def set_generator_complete(self, payload: int | None) -> None:
if not payload:
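All of the properties above answer the same question: does a signal key currently exist in Redis? redis-py's exists() returns a count of matching keys, so the expanded if/return form and bool(self.redis.exists(...)) are equivalent. A tiny sketch with a fake client so it runs offline; the key name is illustrative:

class FakeRedis:
    """Minimal stand-in for redis.Redis: exists() returns a count, like redis-py."""

    def __init__(self) -> None:
        self._keys: set[str] = set()

    def set(self, key: str, value: int) -> None:
        self._keys.add(key)

    def exists(self, key: str) -> int:
        return 1 if key in self._keys else 0


class ConnectorIndexSignals:
    def __init__(self, redis: FakeRedis, fence_key: str) -> None:
        self.redis = redis
        self.fence_key = fence_key

    @property
    def fenced(self) -> bool:
        # Expanded form used in the diff above; bool(self.redis.exists(...)) behaves the same.
        if self.redis.exists(self.fence_key):
            return True
        return False


r = FakeRedis()
signals = ConnectorIndexSignals(r, "connectorindexing_fence_1")
print(signals.fenced)  # False
r.set("connectorindexing_fence_1", 0)
print(signals.fenced)  # True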
||||
@@ -52,8 +52,8 @@ class RedisConnectorPrune:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -5,21 +5,14 @@ class RedisConnectorStop:
|
||||
"""Manages interactions with redis for stop signaling. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
PREFIX = "connectorstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_PREFIX = "connectorstop_fence"
|
||||
|
||||
# if this timeout is exceeded, the caller may decide to take more
|
||||
# drastic measures
|
||||
TIMEOUT_PREFIX = f"{PREFIX}_timeout"
|
||||
TIMEOUT_TTL = 300
|
||||
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id: int = id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.timeout_key: str = f"{self.TIMEOUT_PREFIX}_{id}"
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
@@ -35,22 +28,7 @@ class RedisConnectorStop:
|
||||
|
||||
self.redis.set(self.fence_key, 0)
|
||||
|
||||
@property
|
||||
def timed_out(self) -> bool:
|
||||
if self.redis.exists(self.timeout_key):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def set_timeout(self) -> None:
|
||||
"""After calling this, call timed_out to determine if the timeout has been
|
||||
exceeded."""
|
||||
self.redis.set(f"{self.timeout_key}", 0, ex=self.TIMEOUT_TTL)
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorStop.TIMEOUT_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
@@ -23,7 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -58,7 +58,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
document set up to date over multiple batches.
|
||||
|
||||
@@ -14,8 +14,8 @@ class RedisObjectHelper(ABC):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: str):
|
||||
self._tenant_id: str = tenant_id
|
||||
def __init__(self, tenant_id: str | None, id: str):
|
||||
self._tenant_id: str | None = tenant_id
|
||||
self._id: str = id
|
||||
self.redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -87,7 +87,7 @@ class RedisObjectHelper(ABC):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[int, int] | None:
|
||||
"""First element should be the number of actual tasks generated, second should
|
||||
be the number of docs that were candidates to be synced for the cc pair.
|
||||
|
||||
@@ -24,7 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -59,7 +59,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
user group up to date over multiple batches.
|
||||
|
||||
@@ -37,15 +37,13 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_indexable_chunks(
|
||||
preprocessed_docs: list[dict],
|
||||
tenant_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[list[Document], list[DocMetadataAwareIndexChunk]]:
|
||||
ids_to_documents = {}
|
||||
chunks = []
|
||||
@@ -88,7 +86,7 @@ def _create_indexable_chunks(
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=preprocessed_doc["title_embedding"],
|
||||
tenant_id=tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA,
|
||||
tenant_id=tenant_id,
|
||||
access=default_public_access,
|
||||
document_sets=set(),
|
||||
boost=DEFAULT_BOOST,
|
||||
@@ -113,7 +111,7 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
|
||||
|
||||
def seed_initial_documents(
|
||||
db_session: Session, tenant_id: str, cohere_enabled: bool = False
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Seed initial documents so users don't have an empty index to start
|
||||
|
||||
@@ -123,15 +123,15 @@ def get_cc_pair_full_info(
|
||||
)
|
||||
is_editable_for_current_user = editable_cc_pair is not None
|
||||
|
||||
cc_pair_identifier = ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
document_count_info_list = list(
|
||||
get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pairs=[
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
],
|
||||
cc_pair_identifiers=[cc_pair_identifier],
|
||||
)
|
||||
)
|
||||
documents_indexed = (
|
||||
@@ -620,7 +620,7 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
try:
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session)
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session, tenant_id)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
@@ -646,6 +646,7 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except ValidationError as e:
|
||||
# If validation fails, delete the connector and commit the changes
|
||||
# Ensures we don't leave invalid connectors in the database
|
||||
@@ -659,14 +660,10 @@ def associate_credential_to_connector(
|
||||
)
|
||||
except IntegrityError as e:
|
||||
logger.error(f"IntegrityError: {e}")
|
||||
delete_connector(db_session, connector_id)
|
||||
db_session.commit()
|
||||
|
||||
raise HTTPException(status_code=400, detail="Name must be unique")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error: {e}")
|
||||
|
||||
raise HTTPException(status_code=500, detail="Unexpected error")
|
||||
|
||||
|
||||
|
||||
@@ -72,31 +72,25 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector import update_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids_parallel
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
)
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import delete_service_account_credentials
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs_parallel
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.index_attempt import get_latest_index_attempts
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_by_status
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_parallel
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.file_processing.extract_file_text import convert_docx_to_txt
|
||||
@@ -125,8 +119,8 @@ from onyx.server.documents.models import RunConnectorRequest
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -584,8 +578,6 @@ def get_connector_status(
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
|
||||
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
|
||||
@@ -640,35 +632,23 @@ def get_connector_indexing_status(
|
||||
# Additional checks are done to make sure the connector and credential still exist.
|
||||
# TODO: make this one query ... possibly eager load or wrap in a read transaction
|
||||
# to avoid the complexity of trying to error check throughout the function
|
||||
|
||||
# see https://stackoverflow.com/questions/75758327/
|
||||
# sqlalchemy-method-connection-for-bind-is-already-in-progress
|
||||
# for why we can't pass in the current db_session to these functions
|
||||
(
|
||||
cc_pairs,
|
||||
latest_index_attempts,
|
||||
latest_finished_index_attempts,
|
||||
) = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
# Gets the connector/credential pairs for the user
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, get_editable, None, True, True, True),
|
||||
),
|
||||
(
|
||||
# Gets the most recent index attempt for each connector/credential pair
|
||||
get_latest_index_attempts_parallel,
|
||||
(secondary_index, True, False),
|
||||
),
|
||||
(
|
||||
# Gets the most recent FINISHED index attempt for each connector/credential pair
|
||||
get_latest_index_attempts_parallel,
|
||||
(secondary_index, True, True),
|
||||
),
|
||||
]
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
|
||||
cc_pair_identifiers = [
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
]
|
||||
|
||||
latest_index_attempts = get_latest_index_attempts(
|
||||
secondary_index=secondary_index,
|
||||
db_session=db_session,
|
||||
)
|
||||
cc_pairs = cast(list[ConnectorCredentialPair], cc_pairs)
|
||||
latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts)
|
||||
|
||||
cc_pair_to_latest_index_attempt = {
|
||||
(
|
||||
@@ -678,60 +658,31 @@ def get_connector_indexing_status(
|
||||
for index_attempt in latest_index_attempts
|
||||
}
|
||||
|
||||
cc_pair_to_latest_finished_index_attempt = {
|
||||
(
|
||||
index_attempt.connector_credential_pair.connector_id,
|
||||
index_attempt.connector_credential_pair.credential_id,
|
||||
): index_attempt
|
||||
for index_attempt in latest_finished_index_attempts
|
||||
}
|
||||
|
||||
document_count_info, group_cc_pair_relationships = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
get_document_counts_for_cc_pairs_parallel,
|
||||
(
|
||||
[
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
],
|
||||
),
|
||||
),
|
||||
(
|
||||
get_cc_pair_groups_for_ids_parallel,
|
||||
([cc_pair.id for cc_pair in cc_pairs],),
|
||||
),
|
||||
]
|
||||
document_count_info = get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=cc_pair_identifiers,
|
||||
)
|
||||
document_count_info = cast(list[tuple[int, int, int]], document_count_info)
|
||||
group_cc_pair_relationships = cast(
|
||||
list[UserGroup__ConnectorCredentialPair], group_cc_pair_relationships
|
||||
)
|
||||
|
||||
cc_pair_to_document_cnt = {
|
||||
(connector_id, credential_id): cnt
|
||||
for connector_id, credential_id, cnt in document_count_info
|
||||
}
|
||||
|
||||
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
|
||||
db_session=db_session,
|
||||
cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs],
|
||||
)
|
||||
group_cc_pair_relationships_dict: dict[int, list[int]] = {}
|
||||
for relationship in group_cc_pair_relationships:
|
||||
group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append(
|
||||
relationship.user_group_id
|
||||
)
|
||||
|
||||
connector_to_cc_pair_ids: dict[int, list[int]] = {}
|
||||
for cc_pair in cc_pairs:
|
||||
connector_to_cc_pair_ids.setdefault(cc_pair.connector_id, []).append(cc_pair.id)
|
||||
search_settings: SearchSettings | None = None
|
||||
if not secondary_index:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
else:
|
||||
search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
get_search_settings = (
|
||||
get_secondary_search_settings
|
||||
if secondary_index
|
||||
else get_current_search_settings
|
||||
)
|
||||
search_settings = get_search_settings(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
# TODO remove this to enable ingestion API
|
||||
if cc_pair.name == "DefaultCCPair":
|
||||
@@ -754,8 +705,11 @@ def get_connector_indexing_status(
|
||||
(connector.id, credential.id)
|
||||
)
|
||||
|
||||
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
|
||||
(connector.id, credential.id)
|
||||
latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session=db_session,
|
||||
connector_credential_pair_id=cc_pair.id,
|
||||
secondary_index=secondary_index,
|
||||
only_finished=True,
|
||||
)
|
||||
|
||||
indexing_statuses.append(
|
||||
@@ -764,9 +718,7 @@ def get_connector_indexing_status(
|
||||
name=cc_pair.name,
|
||||
in_progress=in_progress,
|
||||
cc_pair_status=cc_pair.status,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
connector, connector_to_cc_pair_ids.get(connector.id, [])
|
||||
),
|
||||
connector=ConnectorSnapshot.from_connector_db_model(connector),
|
||||
credential=CredentialSnapshot.from_credential_db_model(credential),
|
||||
access_type=cc_pair.access_type,
|
||||
owner=credential.user.email if credential.user else "",
|
||||
@@ -902,6 +854,7 @@ def create_connector_with_mock_credential(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.db.credentials import fetch_credentials_by_source_for_user
|
||||
from onyx.db.credentials import fetch_credentials_for_user
|
||||
from onyx.db.credentials import swap_credentials_connector
|
||||
from onyx.db.credentials import update_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.db.models import User
|
||||
@@ -99,11 +100,13 @@ def swap_credentials_for_connector(
|
||||
credential_swap_req: CredentialSwapRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse:
|
||||
validate_ccpair_for_user(
|
||||
credential_swap_req.connector_id,
|
||||
credential_swap_req.new_credential_id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
connector_credential_pair = swap_credentials_connector(
|
||||
|
||||
@@ -83,9 +83,7 @@ class ConnectorSnapshot(ConnectorBase):
|
||||
source: DocumentSource
|
||||
|
||||
@classmethod
|
||||
def from_connector_db_model(
|
||||
cls, connector: Connector, credential_ids: list[int] | None = None
|
||||
) -> "ConnectorSnapshot":
|
||||
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
|
||||
return ConnectorSnapshot(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
@@ -94,10 +92,9 @@ class ConnectorSnapshot(ConnectorBase):
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
refresh_freq=connector.refresh_freq,
|
||||
prune_freq=connector.prune_freq,
|
||||
credential_ids=(
|
||||
credential_ids
|
||||
or [association.credential.id for association in connector.credentials]
|
||||
),
|
||||
credential_ids=[
|
||||
association.credential.id for association in connector.credentials
|
||||
],
|
||||
indexing_start=connector.indexing_start,
|
||||
time_created=connector.time_created,
|
||||
time_updated=connector.time_updated,
|
||||
|
||||
@@ -49,7 +49,6 @@ def get_folders(
|
||||
name=chat_session.description,
|
||||
persona_id=chat_session.persona_id,
|
||||
time_created=chat_session.time_created.isoformat(),
|
||||
time_updated=chat_session.time_updated.isoformat(),
|
||||
shared_status=chat_session.shared_status,
|
||||
folder_id=folder.id,
|
||||
)
|
||||
|
||||
@@ -343,8 +343,7 @@ def list_bot_configs(
|
||||
]
|
||||
|
||||
|
||||
MAX_SLACK_PAGES = 5
|
||||
SLACK_API_CHANNELS_PER_PAGE = 100
|
||||
MAX_CHANNELS = 200
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -356,8 +355,8 @@ def get_all_channels_from_slack_api(
_: User | None = Depends(current_admin_user),
) -> list[SlackChannel]:
"""
Fetches channels the bot is a member of from the Slack API.
Handles pagination with a limit to avoid excessive API calls.
Fetches all channels from the Slack API.
If the workspace has 200 or more channels, we raise an error.
"""
tokens = fetch_slack_bot_tokens(db_session, bot_id)
if not tokens or "bot_token" not in tokens:
@@ -366,60 +365,28 @@ def get_all_channels_from_slack_api(
)

client = WebClient(token=tokens["bot_token"])
all_channels = []
next_cursor = None
current_page = 0

try:
# Use users_conversations with limited pagination
while current_page < MAX_SLACK_PAGES:
current_page += 1

# Make API call with cursor if we have one
if next_cursor:
response = client.users_conversations(
types="public_channel,private_channel",
exclude_archived=True,
cursor=next_cursor,
limit=SLACK_API_CHANNELS_PER_PAGE,
)
else:
response = client.users_conversations(
types="public_channel,private_channel",
exclude_archived=True,
limit=SLACK_API_CHANNELS_PER_PAGE,
)

# Add channels to our list
if "channels" in response and response["channels"]:
all_channels.extend(response["channels"])

# Check if we need to paginate
if (
"response_metadata" in response
and "next_cursor" in response["response_metadata"]
):
next_cursor = response["response_metadata"]["next_cursor"]
if next_cursor:
if current_page == MAX_SLACK_PAGES:
raise HTTPException(
status_code=400,
detail="Workspace has too many channels to paginate over in this call.",
)
continue

# If we get here, no more pages
break
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=MAX_CHANNELS,
)

channels = [
SlackChannel(id=channel["id"], name=channel["name"])
for channel in all_channels
for channel in response["channels"]
]

if len(channels) == MAX_CHANNELS:
raise HTTPException(
status_code=400,
detail=f"Workspace has {MAX_CHANNELS} or more channels.",
)

return channels

except SlackApiError as e:
# Handle rate limiting or other API errors
raise HTTPException(
status_code=500,
detail=f"Error fetching channels from Slack API: {str(e)}",
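The replacement above drops cursor pagination in favour of a single conversations_list call capped at MAX_CHANNELS, while the removed branch walked users_conversations page by page. A compact, self-contained version of that cursor-pagination loop, with an offline stub standing in for slack_sdk's WebClient so the sketch runs without a token:

from typing import Any


class StubSlackClient:
    """Offline stand-in for WebClient.users_conversations: serves two fixed pages."""

    def __init__(self) -> None:
        self._pages = [
            {"channels": [{"id": "C1", "name": "general"}],
             "response_metadata": {"next_cursor": "page-2"}},
            {"channels": [{"id": "C2", "name": "support"}],
             "response_metadata": {"next_cursor": ""}},
        ]

    def users_conversations(self, *, cursor: str | None = None, limit: int = 100,
                            **_: Any) -> dict:
        return self._pages[1] if cursor else self._pages[0]


def fetch_bot_channels(client: StubSlackClient, max_pages: int = 5,
                       per_page: int = 100) -> list[dict]:
    channels: list[dict] = []
    cursor: str | None = None
    for _ in range(max_pages):
        response = client.users_conversations(cursor=cursor, limit=per_page)
        channels.extend(response.get("channels", []))
        cursor = response.get("response_metadata", {}).get("next_cursor") or None
        if not cursor:
            break  # no more pages
    return channels


print([c["name"] for c in fetch_bot_channels(StubSlackClient())])  # ['general', 'support']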
||||
@@ -147,11 +147,9 @@ def list_threads(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import UploadFile
|
||||
@@ -47,7 +44,6 @@ from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
@@ -69,13 +65,10 @@ from onyx.secondary_llm_flows.chat_session_naming import (
from onyx.server.query_and_chat.models import ChatFeedbackRequest
from onyx.server.query_and_chat.models import ChatMessageIdentifier
from onyx.server.query_and_chat.models import ChatRenameRequest
from onyx.server.query_and_chat.models import ChatSearchResponse
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import ChatSessionDetailResponse
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionGroup
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.query_and_chat.models import ChatSessionSummary
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import CreateChatSessionID
@@ -119,7 +112,6 @@ def get_user_chat_sessions(
            name=chat.description,
            persona_id=chat.persona_id,
            time_created=chat.time_created.isoformat(),
            time_updated=chat.time_updated.isoformat(),
            shared_status=chat.shared_status,
            folder_id=chat.folder_id,
            current_alternate_model=chat.current_alternate_model,
@@ -802,84 +794,3 @@ def fetch_chat_file(
    file_io = file_store.read_file(file_id, mode="b")

    return StreamingResponse(file_io, media_type=media_type)


@router.get("/search")
async def search_chats(
    query: str | None = Query(None),
    page: int = Query(1),
    page_size: int = Query(10),
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSearchResponse:
    """
    Search for chat sessions based on the provided query.
    If no query is provided, returns recent chat sessions.
    """

    # Use the enhanced database function for chat search
    chat_sessions, has_more = search_chat_sessions(
        user_id=user.id if user else None,
        db_session=db_session,
        query=query,
        page=page,
        page_size=page_size,
        include_deleted=False,
        include_onyxbot_flows=False,
    )

    # Group chat sessions by time period
    today = datetime.datetime.now().date()
    yesterday = today - timedelta(days=1)
    this_week = today - timedelta(days=7)
    this_month = today - timedelta(days=30)

    today_chats: list[ChatSessionSummary] = []
    yesterday_chats: list[ChatSessionSummary] = []
    this_week_chats: list[ChatSessionSummary] = []
    this_month_chats: list[ChatSessionSummary] = []
    older_chats: list[ChatSessionSummary] = []

    for session in chat_sessions:
        session_date = session.time_created.date()

        chat_summary = ChatSessionSummary(
            id=session.id,
            name=session.description,
            persona_id=session.persona_id,
            time_created=session.time_created,
            shared_status=session.shared_status,
            folder_id=session.folder_id,
            current_alternate_model=session.current_alternate_model,
            current_temperature_override=session.temperature_override,
        )

        if session_date == today:
            today_chats.append(chat_summary)
        elif session_date == yesterday:
            yesterday_chats.append(chat_summary)
        elif session_date > this_week:
            this_week_chats.append(chat_summary)
        elif session_date > this_month:
            this_month_chats.append(chat_summary)
        else:
            older_chats.append(chat_summary)

    # Create groups
    groups = []
    if today_chats:
        groups.append(ChatSessionGroup(title="Today", chats=today_chats))
    if yesterday_chats:
        groups.append(ChatSessionGroup(title="Yesterday", chats=yesterday_chats))
    if this_week_chats:
        groups.append(ChatSessionGroup(title="This Week", chats=this_week_chats))
    if this_month_chats:
        groups.append(ChatSessionGroup(title="This Month", chats=this_month_chats))
    if older_chats:
        groups.append(ChatSessionGroup(title="Older", chats=older_chats))

    return ChatSearchResponse(
        groups=groups,
        has_more=has_more,
        next_page=page + 1 if has_more else None,
    )
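For reviewers who want to exercise the new endpoint by hand, a minimal client sketch follows; the host, the /chat route prefix, and the cookie-based auth are assumptions about a local deployment, not something this diff specifies.

import requests

# Assumed local API server; adjust host, route prefix, and auth for your deployment.
resp = requests.get(
    "http://localhost:8080/chat/search",
    params={"query": "quarterly report", "page": 1, "page_size": 10},
    cookies={"fastapiusersauth": "<session-cookie>"},  # placeholder session cookie
)
resp.raise_for_status()
payload = resp.json()
for group in payload["groups"]:  # e.g. "Today", "Yesterday", "This Week"
    print(group["title"], [chat["name"] for chat in group["chats"]])
if payload["has_more"]:
    print("next page:", payload["next_page"])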
@@ -24,7 +24,6 @@ from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.tools.models import ToolCallFinalResult


if TYPE_CHECKING:
    pass

@@ -181,7 +180,6 @@ class ChatSessionDetails(BaseModel):
    name: str | None
    persona_id: int | None = None
    time_created: str
    time_updated: str
    shared_status: ChatSessionSharedStatus
    folder_id: int | None = None
    current_alternate_model: str | None = None
@@ -242,7 +240,6 @@ class ChatMessageDetail(BaseModel):
    files: list[FileDescriptor]
    tool_call: ToolCallFinalResult | None
    refined_answer_improvement: bool | None = None
    is_agentic: bool | None = None
    error: str | None = None

    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
@@ -285,35 +282,3 @@ class AdminSearchRequest(BaseModel):

class AdminSearchResponse(BaseModel):
    documents: list[SearchDoc]


class ChatSessionSummary(BaseModel):
    id: UUID
    name: str | None = None
    persona_id: int | None = None
    time_created: datetime
    shared_status: ChatSessionSharedStatus
    folder_id: int | None = None
    current_alternate_model: str | None = None
    current_temperature_override: float | None = None


class ChatSessionGroup(BaseModel):
    title: str
    chats: list[ChatSessionSummary]


class ChatSearchResponse(BaseModel):
    groups: list[ChatSessionGroup]
    has_more: bool
    next_page: int | None = None


class ChatSearchRequest(BaseModel):
    query: str | None = None
    page: int = 1
    page_size: int = 10


class CreateChatResponse(BaseModel):
    chat_session_id: str

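A quick sketch of how the search models above compose into a response; the values are illustrative, and the import path and member name of the shared-status enum are assumptions.

from datetime import datetime
from uuid import uuid4

from onyx.server.query_and_chat.models import ChatSearchResponse
from onyx.server.query_and_chat.models import ChatSessionGroup
from onyx.server.query_and_chat.models import ChatSessionSummary
# Assumed location of the shared-status enum; adjust to where it lives in the repo.
from onyx.configs.constants import ChatSessionSharedStatus

summary = ChatSessionSummary(
    id=uuid4(),
    name="Weekly sync notes",  # illustrative value
    time_created=datetime.now(),
    shared_status=ChatSessionSharedStatus.PRIVATE,  # assumed enum member
)
response = ChatSearchResponse(
    groups=[ChatSessionGroup(title="Today", chats=[summary])],
    has_more=False,
    next_page=None,
)
print(response.model_dump_json(indent=2))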
@@ -159,7 +159,6 @@ def get_user_search_sessions(
            name=sessions_with_documents_dict[search.id],
            persona_id=search.persona_id,
            time_created=search.time_created.isoformat(),
            time_updated=search.time_updated.isoformat(),
            shared_status=search.shared_status,
            folder_id=search.folder_id,
            current_alternate_model=search.current_alternate_model,

@@ -4,9 +4,7 @@ from enum import Enum
from pydantic import BaseModel

from onyx.configs.constants import NotificationType
from onyx.configs.constants import QueryHistoryType
from onyx.db.models import Notification as NotificationDBModel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA


class PageType(str, Enum):
@@ -51,10 +49,9 @@ class Settings(BaseModel):

    temperature_override_enabled: bool | None = False
    auto_scroll: bool | None = False
    query_history_type: QueryHistoryType | None = None


class UserSettings(Settings):
    notifications: list[Notification]
    needs_reindexing: bool
    tenant_id: str = POSTGRES_DEFAULT_SCHEMA
    tenant_id: str | None = None

@@ -1,4 +1,3 @@
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
@@ -46,7 +45,6 @@ def load_settings() -> Settings:
        anonymous_user_enabled = False

    settings.anonymous_user_enabled = anonymous_user_enabled
    settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
    return settings


@@ -65,7 +65,7 @@ logger = setup_logger()


def setup_onyx(
    db_session: Session, tenant_id: str, cohere_enabled: bool = False
    db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
) -> None:
    """
    Setup Onyx for a particular tenant. In the Single Tenant case, it will set it up for the default schema

@@ -37,7 +37,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.61.16
litellm==1.60.2
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16
litellm==1.60.2
sentry-sdk[fastapi,celery,starlette]==2.14.0
@@ -260,7 +260,7 @@ def get_documents_for_tenant_connector(
def search_for_document(
    index_name: str,
    document_id: str | None = None,
    tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
    tenant_id: str | None = None,
    max_hits: int | None = 10,
) -> List[Dict[str, Any]]:
    yql_query = f"select * from sources {index_name}"
@@ -507,9 +507,9 @@ def get_number_of_chunks_we_think_exist(

class VespaDebugging:
    # Class for managing Vespa debugging actions.
    def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
    def __init__(self, tenant_id: str | None = None):
        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
        self.tenant_id = tenant_id
        self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id
        self.index_name = get_index_name(self.tenant_id)

    def sample_document_counts(self) -> None:
@@ -603,7 +603,7 @@ class VespaDebugging:
        delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)

    def search_for_document(
        self, document_id: str | None = None, tenant_id: str = POSTGRES_DEFAULT_SCHEMA
        self, document_id: str | None = None, tenant_id: str | None = None
    ) -> List[Dict[str, Any]]:
        return search_for_document(self.index_name, document_id, tenant_id)


@@ -8,7 +8,6 @@ from sqlalchemy.orm import Session
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.search_settings import get_active_search_settings
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

# Modify sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -75,7 +74,7 @@ def _unsafe_deletion(
    for document in documents:
        document_index.delete_single(
            doc_id=document.id,
            tenant_id=POSTGRES_DEFAULT_SCHEMA,
            tenant_id=None,
            chunk_count=document.chunk_count,
        )


@@ -6,7 +6,6 @@ from sqlalchemy import text
from sqlalchemy.orm import Session

from onyx.document_index.document_index_utils import get_multipass_config
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

# makes it so `PYTHONPATH=.` is not required when running this script
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -97,9 +96,7 @@ def main() -> None:
    try:
        print(f"Deleting document {doc_id} in Vespa")
        chunks_deleted = vespa_index.delete_single(
            doc_id,
            tenant_id=POSTGRES_DEFAULT_SCHEMA,
            chunk_count=document.chunk_count,
            doc_id, tenant_id=None, chunk_count=document.chunk_count
        )
        if chunks_deleted > 0:
            print(

@@ -18,7 +18,5 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
def get_current_tenant_id() -> str:
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    if tenant_id is None:
        if not MULTI_TENANT:
            return POSTGRES_DEFAULT_SCHEMA
        raise RuntimeError("Tenant ID is not set. This should never happen.")
    return tenant_id
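To make the tenant-resolution change concrete, a small sketch of setting and reading the tenant contextvar; the import path and the tenant ID value are assumptions for illustration.

from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR, get_current_tenant_id

# Hypothetical tenant ID; get_current_tenant_id() reads this contextvar
# (see the hunk above for how unset values are handled).
token = CURRENT_TENANT_ID_CONTEXTVAR.set("tenant_abc123")
try:
    print(get_current_tenant_id())  # -> "tenant_abc123"
finally:
    CURRENT_TENANT_ID_CONTEXTVAR.reset(token)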
@@ -87,7 +87,7 @@ def test_confluence_connector_basic(
    assert len(txt_doc.sections) == 1
    assert txt_doc.sections[0].text == "small"
    assert txt_doc.primary_owners
    assert txt_doc.primary_owners[0].email == "chris@onyx.app"
    assert txt_doc.primary_owners[0].email == "chris@danswer.ai"
    assert (
        txt_doc.sections[0].link
        == "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"

@@ -10,8 +10,7 @@ from onyx.connectors.onyx_jira.connector import JiraConnector
@pytest.fixture
def jira_connector() -> JiraConnector:
    connector = JiraConnector(
        jira_base_url="https://danswerai.atlassian.net",
        project_key="AS",
        "https://danswerai.atlassian.net/jira/software/c/projects/AS/boards/6",
        comment_email_blacklist=[],
    )
    connector.load_credentials(

@@ -4,10 +4,6 @@ from onyx.connectors.models import Document
from onyx.connectors.web.connector import WEB_CONNECTOR_VALID_SETTINGS
from onyx.connectors.web.connector import WebConnector

EXPECTED_QUOTE = (
    "If you can't explain it to a six year old, you don't understand it yourself."
)


# NOTE(rkuo): we will probably need to adjust this test to point at our own test site
# to avoid depending on a third party site
@@ -15,7 +11,7 @@ EXPECTED_QUOTE = (
def web_connector(request: pytest.FixtureRequest) -> WebConnector:
    scroll_before_scraping = request.param
    connector = WebConnector(
        base_url="https://quotes.toscrape.com/scroll",
        base_url="https://developer.onewelcome.com",
        web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
        scroll_before_scraping=scroll_before_scraping,
    )
@@ -32,7 +28,7 @@ def test_web_connector_scroll(web_connector: WebConnector) -> None:

    assert len(all_docs) == 1
    doc = all_docs[0]
    assert EXPECTED_QUOTE in doc.sections[0].text
    assert "Onegini Identity Cloud" in doc.sections[0].text


@pytest.mark.parametrize("web_connector", [False], indirect=True)
@@ -45,4 +41,4 @@ def test_web_connector_no_scroll(web_connector: WebConnector) -> None:

    assert len(all_docs) == 1
    doc = all_docs[0]
    assert EXPECTED_QUOTE not in doc.sections[0].text
    assert "Onegini Identity Cloud" not in doc.sections[0].text

@@ -71,13 +71,12 @@ def litellm_embedding_model() -> EmbeddingModel:
        normalize=True,
        query_prefix=None,
        passage_prefix=None,
        api_key=os.getenv("LITELLM_API_KEY"),
        api_key=os.getenv("LITE_LLM_API_KEY"),
        provider_type=EmbeddingProvider.LITELLM,
        api_url=os.getenv("LITELLM_API_URL"),
        api_url=os.getenv("LITE_LLM_API_URL"),
    )


@pytest.mark.skip(reason="re-enable when we can get the correct litellm key and url")
def test_litellm_embedding(litellm_embedding_model: EmbeddingModel) -> None:
    _run_embeddings(VALID_SAMPLE, litellm_embedding_model, 1536)
    _run_embeddings(TOO_LONG_SAMPLE, litellm_embedding_model, 1536)
@@ -118,11 +117,6 @@ def azure_embedding_model() -> EmbeddingModel:
    )


def test_azure_embedding(azure_embedding_model: EmbeddingModel) -> None:
    _run_embeddings(VALID_SAMPLE, azure_embedding_model, 1536)
    _run_embeddings(TOO_LONG_SAMPLE, azure_embedding_model, 1536)


# NOTE (chris): this test doesn't work, and I do not know why
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
#     """NOTE: this test relies on a very low rate limit for the Azure API +

@@ -1,37 +0,0 @@
services:
  indexing_model_server:
    image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
        echo 'Skipping service...';
        exit 0;
      else
        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - INDEX_BATCH_SIZE=${INDEX_BATCH_SIZE:-}
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      - INDEXING_ONLY=True
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
      - CLIENT_EMBEDDING_TIMEOUT=${CLIENT_EMBEDDING_TIMEOUT:-}

      # Analytics Configs
      - SENTRY_DSN=${SENTRY_DSN:-}
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"
    ports:
      - "9000:9000" # <-- Add this line to expose the port to the host

volumes:
  indexing_huggingface_model_cache:

@@ -68,28 +68,6 @@ const nextConfig = {
      },
    ];
  },
  async rewrites() {
    return [
      {
        source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
        destination: `${
          process.env.INTERNAL_URL || "http://localhost:8080"
        }/docs/:path*`,
      },
      {
        source: "/api/docs", // if you also need the exact /api/docs
        destination: `${
          process.env.INTERNAL_URL || "http://localhost:8080"
        }/docs`,
      },
      {
        source: "/openapi.json",
        destination: `${
          process.env.INTERNAL_URL || "http://localhost:8080"
        }/openapi.json`,
      },
    ];
  },
};

// Sentry configuration for error monitoring:
89
web/package-lock.json
generated
@@ -70,8 +70,6 @@
        "recharts": "^2.13.1",
        "rehype-katex": "^7.0.1",
        "rehype-prism-plus": "^2.0.0",
        "rehype-sanitize": "^6.0.0",
        "rehype-stringify": "^10.0.1",
        "remark-gfm": "^4.0.0",
        "remark-math": "^6.0.0",
        "semver": "^7.5.4",
@@ -11743,54 +11741,6 @@
      "resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz",
      "integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA=="
    },
    "node_modules/hast-util-sanitize": {
      "version": "5.0.2",
      "resolved": "https://registry.npmjs.org/hast-util-sanitize/-/hast-util-sanitize-5.0.2.tgz",
      "integrity": "sha512-3yTWghByc50aGS7JlGhk61SPenfE/p1oaFeNwkOOyrscaOkMGrcW9+Cy/QAIOBpZxP1yqDIzFMR0+Np0i0+usg==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "@ungap/structured-clone": "^1.0.0",
        "unist-util-position": "^5.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-to-html": {
      "version": "9.0.5",
      "resolved": "https://registry.npmjs.org/hast-util-to-html/-/hast-util-to-html-9.0.5.tgz",
      "integrity": "sha512-OguPdidb+fbHQSU4Q4ZiLKnzWo8Wwsf5bZfbvu7//a9oTYoqD/fWpe96NuHkoS9h0ccGOTe0C4NGXdtS0iObOw==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "@types/unist": "^3.0.0",
        "ccount": "^2.0.0",
        "comma-separated-tokens": "^2.0.0",
        "hast-util-whitespace": "^3.0.0",
        "html-void-elements": "^3.0.0",
        "mdast-util-to-hast": "^13.0.0",
        "property-information": "^7.0.0",
        "space-separated-tokens": "^2.0.0",
        "stringify-entities": "^4.0.0",
        "zwitch": "^2.0.4"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-to-html/node_modules/property-information": {
      "version": "7.0.0",
      "resolved": "https://registry.npmjs.org/property-information/-/property-information-7.0.0.tgz",
      "integrity": "sha512-7D/qOz/+Y4X/rzSB6jKxKUsQnphO046ei8qxG59mtM3RG3DHgTK81HrxrmoDVINJb8NKT5ZsRbwHvQ6B68Iyhg==",
      "license": "MIT",
      "funding": {
        "type": "github",
        "url": "https://github.com/sponsors/wooorm"
      }
    },
    "node_modules/hast-util-to-jsx-runtime": {
      "version": "2.3.0",
      "resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.0.tgz",
@@ -11969,16 +11919,6 @@
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/html-void-elements": {
      "version": "3.0.0",
      "resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz",
      "integrity": "sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==",
      "license": "MIT",
      "funding": {
        "type": "github",
        "url": "https://github.com/sponsors/wooorm"
      }
    },
    "node_modules/html-webpack-plugin": {
      "version": "5.6.3",
      "resolved": "https://registry.npmjs.org/html-webpack-plugin/-/html-webpack-plugin-5.6.3.tgz",
@@ -19185,35 -19125,6 @@
        "unist-util-visit": "^5.0.0"
      }
    },
    "node_modules/rehype-sanitize": {
      "version": "6.0.0",
      "resolved": "https://registry.npmjs.org/rehype-sanitize/-/rehype-sanitize-6.0.0.tgz",
      "integrity": "sha512-CsnhKNsyI8Tub6L4sm5ZFsme4puGfc6pYylvXo1AeqaGbjOYyzNv3qZPwvs0oMJ39eryyeOdmxwUIo94IpEhqg==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "hast-util-sanitize": "^5.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/rehype-stringify": {
      "version": "10.0.1",
      "resolved": "https://registry.npmjs.org/rehype-stringify/-/rehype-stringify-10.0.1.tgz",
      "integrity": "sha512-k9ecfXHmIPuFVI61B9DeLPN0qFHfawM6RsuX48hoqlaKSF61RskNjSm1lI8PhBEM0MRdLxVVm4WmTqJQccH9mA==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "hast-util-to-html": "^9.0.0",
        "unified": "^11.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/relateurl": {
      "version": "0.2.7",
      "resolved": "https://registry.npmjs.org/relateurl/-/relateurl-0.2.7.tgz",
Some files were not shown because too many files have changed in this diff