Compare commits

..

10 Commits

Author SHA1 Message Date
pablonyx
25d9266da4 update 2025-02-26 08:48:35 -08:00
Weves
23073d91b9 reduce number of chars to index for search 2025-02-25 19:27:50 -08:00
Chris Weaver
f767b1f476 Fix confluence permission syncing at scale (#4129)
* Fix confluence permission syncing at scale

* Remove line

* Better log message

* Adjust log
2025-02-25 19:22:52 -08:00
pablonyx
9ffc8cb2c4 k 2025-02-25 18:15:49 -08:00
pablonyx
98bfb58147 Handle bad slack configurations– multi tenant (#4118)
* k

* quick nit

* k

* k
2025-02-25 22:22:54 +00:00
evan-danswer
6ce810e957 faster indexing status at scale plus minor cleanups (#4081)
* faster indexing status at scale plus minor cleanups

* mypy

* address chris comments

* remove extra prints
2025-02-25 21:22:26 +00:00
pablonyx
07b0b57b31 (nit) bump timeout 2025-02-25 14:10:30 -08:00
pablonyx
118cdd7701 Chat search (#4113)
* add chat search

* don't add the bible

* base functional

* k

* k

* functioning

* functioning well

* functioning well

* k

* delete bible

* quick cleanup

* quick cleanup

* k

* fixed frontend hooks

* delete bible

* nit

* nit

* nit

* fix build

* k

* improved debouncing

* address comments

* fix alembic

* k
2025-02-25 20:49:46 +00:00
rkuo-danswer
ac83b4c365 validate connector deletion (#4108)
* validate connector deletion

* fixes

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-25 20:35:21 +00:00
pablonyx
fa408ff447 add 3.7 (#4116) 2025-02-25 12:41:40 -08:00
24 changed files with 610 additions and 137 deletions

View File

@@ -17,10 +17,11 @@ depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
op.execute(
"""
CREATE INDEX idx_chat_message_message_lower
ON chat_message (LOWER(message))
ON chat_message (LOWER(substring(message, 1, 1500)))
"""
)

View File

@@ -224,7 +224,7 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)

View File

@@ -92,7 +92,8 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
"""This is a redis specific way to build a list of tasks in a queue and return them
as a set.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.

View File

@@ -8,16 +8,21 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
@@ -109,6 +114,7 @@ def check_for_connector_deletion_task(
) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -120,6 +126,21 @@ def check_for_connector_deletion_task(
return None
try:
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating connector deletion fences"
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
@@ -243,6 +264,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
redis_connector.delete.set_active()
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
@@ -475,3 +497,171 @@ def monitor_connector_deletion_taskset(
)
redis_connector.delete.reset()
def validate_connector_deletion_fences(
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# validate all existing connector deletion jobs
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
continue
validate_connector_deletion_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
r,
)
lock_beat.reacquire()
return
def validate_connector_deletion_fence(
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.delete.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.delete.payload
except ValidationError:
task_logger.exception(
"validate_connector_deletion_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return
if not payload:
return
# OK, there's actually something for us to validate
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.delete.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_connector_deletion_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.delete.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.delete.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_connector_deletion_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return

View File

@@ -342,6 +342,9 @@ class OnyxRedisSignals:
BLOCK_PRUNING = "signal:block_pruning"
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
"signal:block_validate_connector_deletion_fences"
)
class OnyxRedisConstants:

View File

@@ -11,6 +11,8 @@ from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from requests import HTTPError
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.utils.logger import setup_logger
@@ -161,7 +163,7 @@ class OnyxConfluence(Confluence):
)
def _paginate_url(
self, url_suffix: str, limit: int | None = None
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -236,9 +238,41 @@ class OnyxConfluence(Confluence):
raise e
# yield the results individually
yield from next_response.get("results", [])
results = cast(list[dict[str, Any]], next_response.get("results", []))
yield from results
url_suffix = next_response.get("_links", {}).get("next")
old_url_suffix = url_suffix
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
# make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
if url_suffix and "start" in url_suffix:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start
adjusted_start = previous_start + len(results)
url_suffix = update_param_in_path(
url_suffix, "start", str(adjusted_start)
)
# some APIs don't properly paginate, so we need to manually update the `start` param
if auto_paginate and len(results) > 0:
previous_start = get_start_param_from_url(old_url_suffix)
updated_start = previous_start + len(results)
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
def paginated_cql_retrieval(
self,
@@ -298,7 +332,9 @@ class OnyxConfluence(Confluence):
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(url, limit):
# endpoint doesn't properly paginate, so we need to manually update the `start` param
# thus the auto_paginate flag
for user_result in self._paginate_url(url, limit, auto_paginate=True):
# Example response:
# {
# 'user': {

View File

@@ -2,7 +2,10 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
@@ -10,13 +13,13 @@ from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.onyx_confluence import (
OnyxConfluence,
)
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
logger = setup_logger()
@@ -24,7 +27,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
@@ -47,7 +50,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
@@ -78,7 +81,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_client: "OnyxConfluence",
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
@@ -191,7 +194,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
def attachment_to_content(
confluence_client: OnyxConfluence,
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
@@ -279,3 +282,32 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)
return parse_qs(parsed_url.query).get(param, [None])[0]
def get_start_param_from_url(url: str) -> int:
"""Get the start parameter from a url"""
start_str = get_single_param_from_url(url, "start")
if start_str is None:
return 0
return int(start_str)
def update_param_in_path(path: str, param: str, value: str) -> str:
"""Update a parameter in a path. Path should look something like:
/api/rest/users?start=0&limit=10
"""
parsed_url = urlparse(path)
query_params = parse_qs(parsed_url.query)
query_params[param] = [value]
return (
path.split("?")[0]
+ "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
)

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import TypeVarTuple
from fastapi import HTTPException
from sqlalchemy import delete
@@ -8,15 +9,18 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
@@ -31,10 +35,12 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
R = TypeVarTuple("R")
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
) -> Select[tuple[*R]]:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -98,17 +104,52 @@ def get_connector_credential_pairs_for_user(
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
if eager_load_user:
assert (
eager_load_credential
), "eager_load_credential must be True if eager_load_user is True"
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
if eager_load_user:
load_opts = load_opts.joinedload(Credential.user)
stmt = stmt.options(load_opts)
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
return list(db_session.scalars(stmt).unique().all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_connector_credential_pairs_for_user_parallel(
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_connector_credential_pairs_for_user(
db_session,
user,
get_editable,
ids,
eager_load_connector,
eager_load_credential,
eager_load_user,
)
def get_connector_credential_pairs(
@@ -151,6 +192,16 @@ def get_cc_pair_groups_for_ids(
return list(db_session.scalars(stmt).all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_cc_pair_groups_for_ids_parallel(
cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,

View File

@@ -24,6 +24,7 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -229,12 +230,12 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
# Prepare a list of (connector_id, credential_id) tuples
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
stmt = (
select(
@@ -260,6 +261,16 @@ def get_document_counts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_document_counts_for_cc_pairs_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
) -> Sequence[tuple[int, int, int]]:
with get_session_context_manager() as db_session:
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
def get_access_info_for_document(
db_session: Session,
document_id: str,

View File

@@ -218,6 +218,7 @@ class SqlEngine:
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:

View File

@@ -2,6 +2,7 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TypeVarTuple
from sqlalchemy import and_
from sqlalchemy import delete
@@ -9,9 +10,13 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -368,19 +373,33 @@ def get_latest_index_attempts_by_status(
return db_session.execute(stmt).scalars().all()
T = TypeVarTuple("T")
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
return stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
def get_latest_index_attempts(
secondary_index: bool,
db_session: Session,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
if secondary_index:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
else:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
ids_stmt = ids_stmt.where(SearchSettings.status == status)
if only_finished:
ids_stmt = _add_only_finished_clause(ids_stmt)
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
@@ -395,7 +414,53 @@ def get_latest_index_attempts(
.where(IndexAttempt.id == ids_subquery.c.max_id)
)
return db_session.execute(stmt).scalars().all()
if only_finished:
stmt = _add_only_finished_clause(stmt)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return db_session.execute(stmt).scalars().unique().all()
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_latest_index_attempts_parallel(
secondary_index: bool,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,
eager_load_cc_pair,
only_finished,
)
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
)
if only_finished:
stmt = _add_only_finished_clause(stmt)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
def count_index_attempts_for_connector(
@@ -453,37 +518,12 @@ def get_paginated_index_attempts_for_cc_pair_id(
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.execute(stmt).scalars().all())
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
stmt = stmt.options(
contains_eager(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
if only_finished:
stmt = stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
return list(db_session.execute(stmt).scalars().unique().all())
def get_index_attempts_for_cc_pair(

View File

@@ -103,7 +103,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
api_version_required=False,
custom_config_keys=[],
llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
default_model="claude-3-5-sonnet-20241022",
default_model="claude-3-7-sonnet-20250219",
default_fast_model="claude-3-5-sonnet-20241022",
),
WellKnownLLMProviderDescriptor(

View File

@@ -17,10 +17,12 @@ from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.chat.models import ThreadMessage
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import POD_NAME
@@ -249,7 +251,12 @@ class SlackbotHandler:
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
all_tenants = [
tenant_id
for tenant_id in get_all_tenant_ids()
if tenant_id not in get_gated_tenants()
]
token: Token[str | None]
@@ -416,6 +423,7 @@ class SlackbotHandler:
try:
bot_info = socket_client.web_client.auth_test()
if bot_info["ok"]:
bot_user_id = bot_info["user_id"]
user_info = socket_client.web_client.users_info(user=bot_user_id)
@@ -426,9 +434,23 @@ class SlackbotHandler:
logger.info(
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
)
except SlackApiError as e:
# Only error out if we get a not_authed error
if "not_authed" in str(e):
self.tenant_ids.add(tenant_id)
logger.error(
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
"Error: {e}"
)
return
# Log other Slack API errors but continue
logger.error(
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
)
except Exception as e:
logger.warning(
f"Could not fetch bot name: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
# Log other exceptions but continue
logger.error(
f"Error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
)
# Append the event handler

View File

@@ -33,6 +33,12 @@ class RedisConnectorDelete:
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -41,6 +47,8 @@ class RedisConnectorDelete:
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -77,6 +85,20 @@ class RedisConnectorDelete:
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
def _generate_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
@@ -141,6 +163,7 @@ class RedisConnectorDelete:
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.active_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@@ -153,6 +176,9 @@ class RedisConnectorDelete:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorDelete.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
r.delete(key)

View File

@@ -93,10 +93,7 @@ class RedisConnectorIndex:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorIndexPayload | None:
@@ -106,9 +103,7 @@ class RedisConnectorIndex:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
return payload
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
def set_fence(
self,
@@ -123,10 +118,7 @@ class RedisConnectorIndex:
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def terminating(self, celery_task_id: str) -> bool:
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True
return False
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
@@ -146,10 +138,7 @@ class RedisConnectorIndex:
def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
if self.redis.exists(self.watchdog_key):
return True
return False
return bool(self.redis.exists(self.watchdog_key))
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -160,10 +149,7 @@ class RedisConnectorIndex:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
def set_connector_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -180,10 +166,7 @@ class RedisConnectorIndex:
return False
def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True
return False
return bool(self.redis.exists(self.generator_lock_key))
def set_generator_complete(self, payload: int | None) -> None:
if not payload:

View File

@@ -123,15 +123,15 @@ def get_cc_pair_full_info(
)
is_editable_for_current_user = editable_cc_pair is not None
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
document_count_info_list = list(
get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=[cc_pair_identifier],
cc_pairs=[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
],
)
)
documents_indexed = (

View File

@@ -72,25 +72,31 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector import update_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids_parallel
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_for_user_parallel,
)
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import delete_service_account_credentials
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
from onyx.db.document import get_document_counts_for_cc_pairs
from onyx.db.document import get_document_counts_for_cc_pairs_parallel
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import IndexingMode
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_latest_index_attempts
from onyx.db.index_attempt import get_latest_index_attempts_by_status
from onyx.db.index_attempt import get_latest_index_attempts_parallel
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.file_processing.extract_file_text import convert_docx_to_txt
@@ -119,8 +125,8 @@ from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -578,6 +584,8 @@ def get_connector_status(
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
eager_load_connector=True,
eager_load_credential=True,
)
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
@@ -632,23 +640,35 @@ def get_connector_indexing_status(
# Additional checks are done to make sure the connector and credential still exist.
# TODO: make this one query ... possibly eager load or wrap in a read transaction
# to avoid the complexity of trying to error check throughout the function
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
)
cc_pair_identifiers = [
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
for cc_pair in cc_pairs
]
latest_index_attempts = get_latest_index_attempts(
secondary_index=secondary_index,
db_session=db_session,
# see https://stackoverflow.com/questions/75758327/
# sqlalchemy-method-connection-for-bind-is-already-in-progress
# for why we can't pass in the current db_session to these functions
(
cc_pairs,
latest_index_attempts,
latest_finished_index_attempts,
) = run_functions_tuples_in_parallel(
[
(
# Gets the connector/credential pairs for the user
get_connector_credential_pairs_for_user_parallel,
(user, get_editable, None, True, True, True),
),
(
# Gets the most recent index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, False),
),
(
# Gets the most recent FINISHED index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, True),
),
]
)
cc_pairs = cast(list[ConnectorCredentialPair], cc_pairs)
latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts)
cc_pair_to_latest_index_attempt = {
(
@@ -658,31 +678,60 @@ def get_connector_indexing_status(
for index_attempt in latest_index_attempts
}
document_count_info = get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=cc_pair_identifiers,
cc_pair_to_latest_finished_index_attempt = {
(
index_attempt.connector_credential_pair.connector_id,
index_attempt.connector_credential_pair.credential_id,
): index_attempt
for index_attempt in latest_finished_index_attempts
}
document_count_info, group_cc_pair_relationships = run_functions_tuples_in_parallel(
[
(
get_document_counts_for_cc_pairs_parallel,
(
[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
for cc_pair in cc_pairs
],
),
),
(
get_cc_pair_groups_for_ids_parallel,
([cc_pair.id for cc_pair in cc_pairs],),
),
]
)
document_count_info = cast(list[tuple[int, int, int]], document_count_info)
group_cc_pair_relationships = cast(
list[UserGroup__ConnectorCredentialPair], group_cc_pair_relationships
)
cc_pair_to_document_cnt = {
(connector_id, credential_id): cnt
for connector_id, credential_id, cnt in document_count_info
}
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
db_session=db_session,
cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs],
)
group_cc_pair_relationships_dict: dict[int, list[int]] = {}
for relationship in group_cc_pair_relationships:
group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append(
relationship.user_group_id
)
search_settings: SearchSettings | None = None
if not secondary_index:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_secondary_search_settings(db_session)
connector_to_cc_pair_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_cc_pair_ids.setdefault(cc_pair.connector_id, []).append(cc_pair.id)
get_search_settings = (
get_secondary_search_settings
if secondary_index
else get_current_search_settings
)
search_settings = get_search_settings(db_session)
for cc_pair in cc_pairs:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
@@ -705,11 +754,8 @@ def get_connector_indexing_status(
(connector.id, credential.id)
)
latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id(
db_session=db_session,
connector_credential_pair_id=cc_pair.id,
secondary_index=secondary_index,
only_finished=True,
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
(connector.id, credential.id)
)
indexing_statuses.append(
@@ -718,7 +764,9 @@ def get_connector_indexing_status(
name=cc_pair.name,
in_progress=in_progress,
cc_pair_status=cc_pair.status,
connector=ConnectorSnapshot.from_connector_db_model(connector),
connector=ConnectorSnapshot.from_connector_db_model(
connector, connector_to_cc_pair_ids.get(connector.id, [])
),
credential=CredentialSnapshot.from_credential_db_model(credential),
access_type=cc_pair.access_type,
owner=credential.user.email if credential.user else "",

View File

@@ -83,7 +83,9 @@ class ConnectorSnapshot(ConnectorBase):
source: DocumentSource
@classmethod
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
def from_connector_db_model(
cls, connector: Connector, credential_ids: list[int] | None = None
) -> "ConnectorSnapshot":
return ConnectorSnapshot(
id=connector.id,
name=connector.name,
@@ -92,9 +94,10 @@ class ConnectorSnapshot(ConnectorBase):
connector_specific_config=connector.connector_specific_config,
refresh_freq=connector.refresh_freq,
prune_freq=connector.prune_freq,
credential_ids=[
association.credential.id for association in connector.credentials
],
credential_ids=(
credential_ids
or [association.credential.id for association in connector.credentials]
),
indexing_start=connector.indexing_start,
time_created=connector.time_created,
time_updated=connector.time_updated,

View File

@@ -37,7 +37,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.60.2
litellm==1.61.16
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

View File

@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.60.2
litellm==1.61.16
sentry-sdk[fastapi,celery,starlette]==2.14.0

View File

@@ -68,6 +68,28 @@ const nextConfig = {
},
];
},
async rewrites() {
return [
{
source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/docs/:path*`,
},
{
source: "/api/docs", // if you also need the exact /api/docs
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/docs`,
},
{
source: "/openapi.json",
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/openapi.json`,
},
];
},
};
// Sentry configuration for error monitoring:

View File

@@ -714,10 +714,11 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
"claude-2.1": "Claude 2.1",
"claude-2.0": "Claude 2.0",
"claude-instant-1.2": "Claude Instant 1.2",
"claude-3-5-sonnet-20240620": "Claude 3.5 Sonnet",
"claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet (New)",
"claude-3-5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
"claude-3.5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
"claude-3-5-sonnet-20240620": "Claude 3.5 Sonnet (June 2024)",
"claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet",
"claude-3-7-sonnet-20250219": "Claude 3.7 Sonnet",
"claude-3-5-sonnet-v2@20241022": "Claude 3.5 Sonnet",
"claude-3.5-sonnet-v2@20241022": "Claude 3.5 Sonnet",
"claude-3-5-haiku-20241022": "Claude 3.5 Haiku",
"claude-3-5-haiku@20241022": "Claude 3.5 Haiku",
"claude-3.5-haiku@20241022": "Claude 3.5 Haiku",

View File

@@ -71,6 +71,7 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
// standard claude names
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
@@ -88,6 +89,7 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
"anthropic.claude-3-7-sonnet-20250219-v1:0",
// google gemini model names
"gemini-1.5-pro",
"gemini-1.5-flash",

View File

@@ -18,7 +18,7 @@ async function verifyAdminPageNavigation(
try {
await expect(page.locator("h1.text-3xl")).toHaveText(pageTitle, {
timeout: 3000,
timeout: 5000,
});
} catch (error) {
console.error(