mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 04:05:48 +00:00
Compare commits
10 Commits
chat_searc
...
fix_openap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25d9266da4 | ||
|
|
23073d91b9 | ||
|
|
f767b1f476 | ||
|
|
9ffc8cb2c4 | ||
|
|
98bfb58147 | ||
|
|
6ce810e957 | ||
|
|
07b0b57b31 | ||
|
|
118cdd7701 | ||
|
|
ac83b4c365 | ||
|
|
fa408ff447 |
@@ -17,10 +17,11 @@ depends_on = None
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create a basic index on the lowercase message column for direct text matching
|
||||
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX idx_chat_message_message_lower
|
||||
ON chat_message (LOWER(message))
|
||||
ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@@ -224,7 +224,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-5-sonnet-20241022",
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
)
|
||||
|
||||
@@ -92,7 +92,8 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
|
||||
|
||||
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""This is a redis specific way to build a list of tasks in a queue.
|
||||
"""This is a redis specific way to build a list of tasks in a queue and return them
|
||||
as a set.
|
||||
|
||||
This helps us read the queue once and then efficiently look for missing tasks
|
||||
in the queue.
|
||||
|
||||
@@ -8,16 +8,21 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
@@ -109,6 +114,7 @@ def check_for_connector_deletion_task(
|
||||
) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -120,6 +126,21 @@ def check_for_connector_deletion_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Exception while validating connector deletion fences"
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -243,6 +264,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
redis_connector.delete.set_active()
|
||||
fence_payload = RedisConnectorDeletePayload(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
@@ -475,3 +497,171 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def validate_connector_deletion_fences(
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
# building lookup table can be expensive, so we won't bother
|
||||
# validating until the queue is small
|
||||
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
|
||||
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
|
||||
return
|
||||
|
||||
queued_upsert_tasks = celery_get_queued_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
|
||||
# validate all existing connector deletion jobs
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
validate_connector_deletion_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
queued_upsert_tasks,
|
||||
r,
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
|
||||
queued_tasks: the celery queue of lightweight permission sync tasks
|
||||
reserved_tasks: prefetched tasks for sync task generator
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.delete.fenced:
|
||||
return
|
||||
|
||||
# in the cloud, the payload format may have changed ...
|
||||
# it's a little sloppy, but just reset the fence for now if that happens
|
||||
# TODO: add intentional cleanup/abort logic
|
||||
try:
|
||||
payload = redis_connector.delete.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"validate_connector_deletion_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
# look up every task in the current taskset in the celery queue
|
||||
# every entry in the taskset should have an associated entry in the celery task queue
|
||||
# because we get the celery tasks first, the entries in our own permissions taskset
|
||||
# should be roughly a subset of the tasks in celery
|
||||
|
||||
# this check isn't very exact, but should be sufficient over a period of time
|
||||
# A single successful check over some number of attempts is sufficient.
|
||||
|
||||
# TODO: if the number of tasks in celery is much lower than than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.delete.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_connector_deletion_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
if tasks_scanned > 0 and tasks_not_in_celery == 0:
|
||||
redis_connector.delete.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.delete.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_connector_deletion_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -342,6 +342,9 @@ class OnyxRedisSignals:
|
||||
BLOCK_PRUNING = "signal:block_pruning"
|
||||
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
|
||||
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
|
||||
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
|
||||
"signal:block_validate_connector_deletion_fences"
|
||||
)
|
||||
|
||||
|
||||
class OnyxRedisConstants:
|
||||
|
||||
@@ -11,6 +11,8 @@ from atlassian import Confluence # type:ignore
|
||||
from pydantic import BaseModel
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -161,7 +163,7 @@ class OnyxConfluence(Confluence):
|
||||
)
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None
|
||||
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This will paginate through the top level query.
|
||||
@@ -236,9 +238,41 @@ class OnyxConfluence(Confluence):
|
||||
raise e
|
||||
|
||||
# yield the results individually
|
||||
yield from next_response.get("results", [])
|
||||
results = cast(list[dict[str, Any]], next_response.get("results", []))
|
||||
yield from results
|
||||
|
||||
url_suffix = next_response.get("_links", {}).get("next")
|
||||
old_url_suffix = url_suffix
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
|
||||
# make sure we don't update the start by more than the amount
|
||||
# of results we were able to retrieve. The Confluence API has a
|
||||
# weird behavior where if you pass in a limit that is too large for
|
||||
# the configured server, it will artificially limit the amount of
|
||||
# results returned BUT will not apply this to the start parameter.
|
||||
# This will cause us to miss results.
|
||||
if url_suffix and "start" in url_suffix:
|
||||
new_start = get_start_param_from_url(url_suffix)
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
if new_start - previous_start > len(results):
|
||||
logger.warning(
|
||||
f"Start was updated by more than the amount of results "
|
||||
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
|
||||
f"Previous Start: {previous_start}, Len Results: {len(results)}."
|
||||
)
|
||||
|
||||
# Update the url_suffix to use the adjusted start
|
||||
adjusted_start = previous_start + len(results)
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "start", str(adjusted_start)
|
||||
)
|
||||
|
||||
# some APIs don't properly paginate, so we need to manually update the `start` param
|
||||
if auto_paginate and len(results) > 0:
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
updated_start = previous_start + len(results)
|
||||
url_suffix = update_param_in_path(
|
||||
old_url_suffix, "start", str(updated_start)
|
||||
)
|
||||
|
||||
def paginated_cql_retrieval(
|
||||
self,
|
||||
@@ -298,7 +332,9 @@ class OnyxConfluence(Confluence):
|
||||
url = "rest/api/search/user"
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
url += f"?cql={cql}{expand_string}"
|
||||
for user_result in self._paginate_url(url, limit):
|
||||
# endpoint doesn't properly paginate, so we need to manually update the `start` param
|
||||
# thus the auto_paginate flag
|
||||
for user_result in self._paginate_url(url, limit, auto_paginate=True):
|
||||
# Example response:
|
||||
# {
|
||||
# 'user': {
|
||||
|
||||
@@ -2,7 +2,10 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
|
||||
@@ -10,13 +13,13 @@ from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
OnyxConfluence,
|
||||
)
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -24,7 +27,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: OnyxConfluence, user_name: str
|
||||
confluence_client: "OnyxConfluence", user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
@@ -47,7 +50,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
@@ -78,7 +81,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
@@ -191,7 +194,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
@@ -279,3 +282,32 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
def get_single_param_from_url(url: str, param: str) -> str | None:
|
||||
"""Get a parameter from a url"""
|
||||
parsed_url = urlparse(url)
|
||||
return parse_qs(parsed_url.query).get(param, [None])[0]
|
||||
|
||||
|
||||
def get_start_param_from_url(url: str) -> int:
|
||||
"""Get the start parameter from a url"""
|
||||
start_str = get_single_param_from_url(url, "start")
|
||||
if start_str is None:
|
||||
return 0
|
||||
return int(start_str)
|
||||
|
||||
|
||||
def update_param_in_path(path: str, param: str, value: str) -> str:
|
||||
"""Update a parameter in a path. Path should look something like:
|
||||
|
||||
/api/rest/users?start=0&limit=10
|
||||
"""
|
||||
parsed_url = urlparse(path)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params[param] = [value]
|
||||
return (
|
||||
path.split("?")[0]
|
||||
+ "?"
|
||||
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
@@ -8,15 +9,18 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
@@ -31,10 +35,12 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
R = TypeVarTuple("R")
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
|
||||
) -> Select[tuple[*R]]:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
return stmt
|
||||
@@ -98,17 +104,52 @@ def get_connector_credential_pairs_for_user(
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
if eager_load_user:
|
||||
assert (
|
||||
eager_load_credential
|
||||
), "eager_load_credential must be True if eager_load_user is True"
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
|
||||
|
||||
if eager_load_credential:
|
||||
load_opts = selectinload(ConnectorCredentialPair.credential)
|
||||
if eager_load_user:
|
||||
load_opts = load_opts.joinedload(Credential.user)
|
||||
stmt = stmt.options(load_opts)
|
||||
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
return list(db_session.scalars(stmt).unique().all())
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_connector_credential_pairs_for_user_parallel(
|
||||
user: User | None,
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
db_session,
|
||||
user,
|
||||
get_editable,
|
||||
ids,
|
||||
eager_load_connector,
|
||||
eager_load_credential,
|
||||
eager_load_user,
|
||||
)
|
||||
|
||||
|
||||
def get_connector_credential_pairs(
|
||||
@@ -151,6 +192,16 @@ def get_cc_pair_groups_for_ids(
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_cc_pair_groups_for_ids_parallel(
|
||||
cc_pair_ids: list[int],
|
||||
) -> list[UserGroup__ConnectorCredentialPair]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
|
||||
|
||||
|
||||
def get_connector_credential_pair_for_user(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
|
||||
@@ -24,6 +24,7 @@ from sqlalchemy.sql.expression import null
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
@@ -229,12 +230,12 @@ def get_document_connector_counts(
|
||||
|
||||
|
||||
def get_document_counts_for_cc_pairs(
|
||||
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
|
||||
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
|
||||
|
||||
# Prepare a list of (connector_id, credential_id) tuples
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
@@ -260,6 +261,16 @@ def get_document_counts_for_cc_pairs(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_document_counts_for_cc_pairs_parallel(
|
||||
cc_pairs: list[ConnectorCredentialPairIdentifier],
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
|
||||
|
||||
|
||||
def get_access_info_for_document(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
|
||||
@@ -218,6 +218,7 @@ class SqlEngine:
|
||||
final_engine_kwargs.update(engine_kwargs)
|
||||
|
||||
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
|
||||
# echo=True here for inspecting all emitted db queries
|
||||
engine = create_engine(connection_string, **final_engine_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
@@ -9,9 +10,13 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
@@ -368,19 +373,33 @@ def get_latest_index_attempts_by_status(
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
T = TypeVarTuple("T")
|
||||
|
||||
|
||||
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
|
||||
return stmt.where(
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_latest_index_attempts(
|
||||
secondary_index: bool,
|
||||
db_session: Session,
|
||||
eager_load_cc_pair: bool = False,
|
||||
only_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
ids_stmt = select(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.id).label("max_id"),
|
||||
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
|
||||
if secondary_index:
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
|
||||
else:
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == status)
|
||||
|
||||
if only_finished:
|
||||
ids_stmt = _add_only_finished_clause(ids_stmt)
|
||||
|
||||
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
ids_subquery = ids_stmt.subquery()
|
||||
@@ -395,7 +414,53 @@ def get_latest_index_attempts(
|
||||
.where(IndexAttempt.id == ids_subquery.c.max_id)
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
if only_finished:
|
||||
stmt = _add_only_finished_clause(stmt)
|
||||
|
||||
if eager_load_cc_pair:
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector_credential_pair),
|
||||
joinedload(IndexAttempt.error_rows),
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).scalars().unique().all()
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_latest_index_attempts_parallel(
|
||||
secondary_index: bool,
|
||||
eager_load_cc_pair: bool = False,
|
||||
only_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_latest_index_attempts(
|
||||
secondary_index,
|
||||
db_session,
|
||||
eager_load_cc_pair,
|
||||
only_finished,
|
||||
)
|
||||
|
||||
|
||||
def get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool,
|
||||
only_finished: bool = True,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
)
|
||||
if only_finished:
|
||||
stmt = _add_only_finished_clause(stmt)
|
||||
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
stmt = stmt.limit(1)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def count_index_attempts_for_connector(
|
||||
@@ -453,37 +518,12 @@ def get_paginated_index_attempts_for_cc_pair_id(
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.execute(stmt).scalars().all())
|
||||
|
||||
|
||||
def get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool,
|
||||
only_finished: bool = True,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
stmt = stmt.options(
|
||||
contains_eager(IndexAttempt.connector_credential_pair),
|
||||
joinedload(IndexAttempt.error_rows),
|
||||
)
|
||||
if only_finished:
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
if secondary_index:
|
||||
stmt = stmt.join(SearchSettings).where(
|
||||
SearchSettings.status == IndexModelStatus.FUTURE
|
||||
)
|
||||
else:
|
||||
stmt = stmt.join(SearchSettings).where(
|
||||
SearchSettings.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
stmt = stmt.limit(1)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
return list(db_session.execute(stmt).scalars().unique().all())
|
||||
|
||||
|
||||
def get_index_attempts_for_cc_pair(
|
||||
|
||||
@@ -103,7 +103,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
api_version_required=False,
|
||||
custom_config_keys=[],
|
||||
llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
|
||||
default_model="claude-3-5-sonnet-20241022",
|
||||
default_model="claude-3-7-sonnet-20250219",
|
||||
default_fast_model="claude-3-5-sonnet-20241022",
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
|
||||
@@ -17,10 +17,12 @@ from prometheus_client import Gauge
|
||||
from prometheus_client import start_http_server
|
||||
from redis.lock import Lock
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import POD_NAME
|
||||
@@ -249,7 +251,12 @@ class SlackbotHandler:
|
||||
- If yes, store them in self.tenant_ids and manage the socket connections.
|
||||
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
|
||||
"""
|
||||
all_tenants = get_all_tenant_ids()
|
||||
|
||||
all_tenants = [
|
||||
tenant_id
|
||||
for tenant_id in get_all_tenant_ids()
|
||||
if tenant_id not in get_gated_tenants()
|
||||
]
|
||||
|
||||
token: Token[str | None]
|
||||
|
||||
@@ -416,6 +423,7 @@ class SlackbotHandler:
|
||||
|
||||
try:
|
||||
bot_info = socket_client.web_client.auth_test()
|
||||
|
||||
if bot_info["ok"]:
|
||||
bot_user_id = bot_info["user_id"]
|
||||
user_info = socket_client.web_client.users_info(user=bot_user_id)
|
||||
@@ -426,9 +434,23 @@ class SlackbotHandler:
|
||||
logger.info(
|
||||
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
|
||||
)
|
||||
except SlackApiError as e:
|
||||
# Only error out if we get a not_authed error
|
||||
if "not_authed" in str(e):
|
||||
self.tenant_ids.add(tenant_id)
|
||||
logger.error(
|
||||
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
|
||||
"Error: {e}"
|
||||
)
|
||||
return
|
||||
# Log other Slack API errors but continue
|
||||
logger.error(
|
||||
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch bot name: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
# Log other exceptions but continue
|
||||
logger.error(
|
||||
f"Error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
)
|
||||
|
||||
# Append the event handler
|
||||
|
||||
@@ -33,6 +33,12 @@ class RedisConnectorDelete:
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
# so we need a signal with a TTL to bridge gaps in our checks
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
@@ -41,6 +47,8 @@ class RedisConnectorDelete:
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
@@ -77,6 +85,20 @@ class RedisConnectorDelete:
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def set_active(self) -> None:
|
||||
"""This sets a signal to keep the permissioning flow from getting cleaned up within
|
||||
the expiration time.
|
||||
|
||||
The slack in timing is needed to avoid race conditions where simply checking
|
||||
the celery queue and task status could result in race conditions."""
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _generate_task_id(self) -> str:
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
@@ -141,6 +163,7 @@ class RedisConnectorDelete:
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
self.redis.delete(self.active_key)
|
||||
self.redis.delete(self.taskset_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@@ -153,6 +176,9 @@ class RedisConnectorDelete:
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorDelete.ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
@@ -93,10 +93,7 @@ class RedisConnectorIndex:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorIndexPayload | None:
|
||||
@@ -106,9 +103,7 @@ class RedisConnectorIndex:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
|
||||
|
||||
return payload
|
||||
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
|
||||
|
||||
def set_fence(
|
||||
self,
|
||||
@@ -123,10 +118,7 @@ class RedisConnectorIndex:
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def terminating(self, celery_task_id: str) -> bool:
|
||||
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
|
||||
|
||||
def set_terminate(self, celery_task_id: str) -> None:
|
||||
"""This sets a signal. It does not block!"""
|
||||
@@ -146,10 +138,7 @@ class RedisConnectorIndex:
|
||||
|
||||
def watchdog_signaled(self) -> bool:
|
||||
"""Check the state of the watchdog."""
|
||||
if self.redis.exists(self.watchdog_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.watchdog_key))
|
||||
|
||||
def set_active(self) -> None:
|
||||
"""This sets a signal to keep the indexing flow from getting cleaned up within
|
||||
@@ -160,10 +149,7 @@ class RedisConnectorIndex:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
def set_connector_active(self) -> None:
|
||||
"""This sets a signal to keep the indexing flow from getting cleaned up within
|
||||
@@ -180,10 +166,7 @@ class RedisConnectorIndex:
|
||||
return False
|
||||
|
||||
def generator_locked(self) -> bool:
|
||||
if self.redis.exists(self.generator_lock_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.generator_lock_key))
|
||||
|
||||
def set_generator_complete(self, payload: int | None) -> None:
|
||||
if not payload:
|
||||
|
||||
@@ -123,15 +123,15 @@ def get_cc_pair_full_info(
|
||||
)
|
||||
is_editable_for_current_user = editable_cc_pair is not None
|
||||
|
||||
cc_pair_identifier = ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
document_count_info_list = list(
|
||||
get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=[cc_pair_identifier],
|
||||
cc_pairs=[
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
documents_indexed = (
|
||||
|
||||
@@ -72,25 +72,31 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector import update_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids_parallel
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
)
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import delete_service_account_credentials
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs_parallel
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.index_attempt import get_latest_index_attempts
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_by_status
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_parallel
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.file_processing.extract_file_text import convert_docx_to_txt
|
||||
@@ -119,8 +125,8 @@ from onyx.server.documents.models import RunConnectorRequest
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -578,6 +584,8 @@ def get_connector_status(
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
|
||||
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
|
||||
@@ -632,23 +640,35 @@ def get_connector_indexing_status(
|
||||
# Additional checks are done to make sure the connector and credential still exist.
|
||||
# TODO: make this one query ... possibly eager load or wrap in a read transaction
|
||||
# to avoid the complexity of trying to error check throughout the function
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
|
||||
cc_pair_identifiers = [
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
]
|
||||
|
||||
latest_index_attempts = get_latest_index_attempts(
|
||||
secondary_index=secondary_index,
|
||||
db_session=db_session,
|
||||
# see https://stackoverflow.com/questions/75758327/
|
||||
# sqlalchemy-method-connection-for-bind-is-already-in-progress
|
||||
# for why we can't pass in the current db_session to these functions
|
||||
(
|
||||
cc_pairs,
|
||||
latest_index_attempts,
|
||||
latest_finished_index_attempts,
|
||||
) = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
# Gets the connector/credential pairs for the user
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, get_editable, None, True, True, True),
|
||||
),
|
||||
(
|
||||
# Gets the most recent index attempt for each connector/credential pair
|
||||
get_latest_index_attempts_parallel,
|
||||
(secondary_index, True, False),
|
||||
),
|
||||
(
|
||||
# Gets the most recent FINISHED index attempt for each connector/credential pair
|
||||
get_latest_index_attempts_parallel,
|
||||
(secondary_index, True, True),
|
||||
),
|
||||
]
|
||||
)
|
||||
cc_pairs = cast(list[ConnectorCredentialPair], cc_pairs)
|
||||
latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts)
|
||||
|
||||
cc_pair_to_latest_index_attempt = {
|
||||
(
|
||||
@@ -658,31 +678,60 @@ def get_connector_indexing_status(
|
||||
for index_attempt in latest_index_attempts
|
||||
}
|
||||
|
||||
document_count_info = get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=cc_pair_identifiers,
|
||||
cc_pair_to_latest_finished_index_attempt = {
|
||||
(
|
||||
index_attempt.connector_credential_pair.connector_id,
|
||||
index_attempt.connector_credential_pair.credential_id,
|
||||
): index_attempt
|
||||
for index_attempt in latest_finished_index_attempts
|
||||
}
|
||||
|
||||
document_count_info, group_cc_pair_relationships = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
get_document_counts_for_cc_pairs_parallel,
|
||||
(
|
||||
[
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
],
|
||||
),
|
||||
),
|
||||
(
|
||||
get_cc_pair_groups_for_ids_parallel,
|
||||
([cc_pair.id for cc_pair in cc_pairs],),
|
||||
),
|
||||
]
|
||||
)
|
||||
document_count_info = cast(list[tuple[int, int, int]], document_count_info)
|
||||
group_cc_pair_relationships = cast(
|
||||
list[UserGroup__ConnectorCredentialPair], group_cc_pair_relationships
|
||||
)
|
||||
|
||||
cc_pair_to_document_cnt = {
|
||||
(connector_id, credential_id): cnt
|
||||
for connector_id, credential_id, cnt in document_count_info
|
||||
}
|
||||
|
||||
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
|
||||
db_session=db_session,
|
||||
cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs],
|
||||
)
|
||||
group_cc_pair_relationships_dict: dict[int, list[int]] = {}
|
||||
for relationship in group_cc_pair_relationships:
|
||||
group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append(
|
||||
relationship.user_group_id
|
||||
)
|
||||
|
||||
search_settings: SearchSettings | None = None
|
||||
if not secondary_index:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
else:
|
||||
search_settings = get_secondary_search_settings(db_session)
|
||||
connector_to_cc_pair_ids: dict[int, list[int]] = {}
|
||||
for cc_pair in cc_pairs:
|
||||
connector_to_cc_pair_ids.setdefault(cc_pair.connector_id, []).append(cc_pair.id)
|
||||
|
||||
get_search_settings = (
|
||||
get_secondary_search_settings
|
||||
if secondary_index
|
||||
else get_current_search_settings
|
||||
)
|
||||
search_settings = get_search_settings(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
# TODO remove this to enable ingestion API
|
||||
if cc_pair.name == "DefaultCCPair":
|
||||
@@ -705,11 +754,8 @@ def get_connector_indexing_status(
|
||||
(connector.id, credential.id)
|
||||
)
|
||||
|
||||
latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session=db_session,
|
||||
connector_credential_pair_id=cc_pair.id,
|
||||
secondary_index=secondary_index,
|
||||
only_finished=True,
|
||||
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
|
||||
(connector.id, credential.id)
|
||||
)
|
||||
|
||||
indexing_statuses.append(
|
||||
@@ -718,7 +764,9 @@ def get_connector_indexing_status(
|
||||
name=cc_pair.name,
|
||||
in_progress=in_progress,
|
||||
cc_pair_status=cc_pair.status,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(connector),
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
connector, connector_to_cc_pair_ids.get(connector.id, [])
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(credential),
|
||||
access_type=cc_pair.access_type,
|
||||
owner=credential.user.email if credential.user else "",
|
||||
|
||||
@@ -83,7 +83,9 @@ class ConnectorSnapshot(ConnectorBase):
|
||||
source: DocumentSource
|
||||
|
||||
@classmethod
|
||||
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
|
||||
def from_connector_db_model(
|
||||
cls, connector: Connector, credential_ids: list[int] | None = None
|
||||
) -> "ConnectorSnapshot":
|
||||
return ConnectorSnapshot(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
@@ -92,9 +94,10 @@ class ConnectorSnapshot(ConnectorBase):
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
refresh_freq=connector.refresh_freq,
|
||||
prune_freq=connector.prune_freq,
|
||||
credential_ids=[
|
||||
association.credential.id for association in connector.credentials
|
||||
],
|
||||
credential_ids=(
|
||||
credential_ids
|
||||
or [association.credential.id for association in connector.credentials]
|
||||
),
|
||||
indexing_start=connector.indexing_start,
|
||||
time_created=connector.time_created,
|
||||
time_updated=connector.time_updated,
|
||||
|
||||
@@ -37,7 +37,7 @@ langchainhub==0.1.21
|
||||
langgraph==0.2.72
|
||||
langgraph-checkpoint==2.0.13
|
||||
langgraph-sdk==0.1.44
|
||||
litellm==1.60.2
|
||||
litellm==1.61.16
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
llama-index==0.9.45
|
||||
|
||||
@@ -12,5 +12,5 @@ torch==2.2.0
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
litellm==1.60.2
|
||||
litellm==1.61.16
|
||||
sentry-sdk[fastapi,celery,starlette]==2.14.0
|
||||
@@ -68,6 +68,28 @@ const nextConfig = {
|
||||
},
|
||||
];
|
||||
},
|
||||
async rewrites() {
|
||||
return [
|
||||
{
|
||||
source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
|
||||
destination: `${
|
||||
process.env.INTERNAL_URL || "http://localhost:8080"
|
||||
}/docs/:path*`,
|
||||
},
|
||||
{
|
||||
source: "/api/docs", // if you also need the exact /api/docs
|
||||
destination: `${
|
||||
process.env.INTERNAL_URL || "http://localhost:8080"
|
||||
}/docs`,
|
||||
},
|
||||
{
|
||||
source: "/openapi.json",
|
||||
destination: `${
|
||||
process.env.INTERNAL_URL || "http://localhost:8080"
|
||||
}/openapi.json`,
|
||||
},
|
||||
];
|
||||
},
|
||||
};
|
||||
|
||||
// Sentry configuration for error monitoring:
|
||||
|
||||
@@ -714,10 +714,11 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
|
||||
"claude-2.1": "Claude 2.1",
|
||||
"claude-2.0": "Claude 2.0",
|
||||
"claude-instant-1.2": "Claude Instant 1.2",
|
||||
"claude-3-5-sonnet-20240620": "Claude 3.5 Sonnet",
|
||||
"claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet (New)",
|
||||
"claude-3-5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
|
||||
"claude-3.5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
|
||||
"claude-3-5-sonnet-20240620": "Claude 3.5 Sonnet (June 2024)",
|
||||
"claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet",
|
||||
"claude-3-7-sonnet-20250219": "Claude 3.7 Sonnet",
|
||||
"claude-3-5-sonnet-v2@20241022": "Claude 3.5 Sonnet",
|
||||
"claude-3.5-sonnet-v2@20241022": "Claude 3.5 Sonnet",
|
||||
"claude-3-5-haiku-20241022": "Claude 3.5 Haiku",
|
||||
"claude-3-5-haiku@20241022": "Claude 3.5 Haiku",
|
||||
"claude-3.5-haiku@20241022": "Claude 3.5 Haiku",
|
||||
|
||||
@@ -71,6 +71,7 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
|
||||
// standard claude names
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
@@ -88,6 +89,7 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
// google gemini model names
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-flash",
|
||||
|
||||
@@ -18,7 +18,7 @@ async function verifyAdminPageNavigation(
|
||||
|
||||
try {
|
||||
await expect(page.locator("h1.text-3xl")).toHaveText(pageTitle, {
|
||||
timeout: 3000,
|
||||
timeout: 5000,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(
|
||||
|
||||
Reference in New Issue
Block a user