mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-22 18:25:45 +00:00
Compare commits
1 Commits
fix_openap
...
silence_mu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ba722b379 |
@@ -1,31 +0,0 @@
|
||||
"""add index
|
||||
|
||||
Revision ID: 8f43500ee275
|
||||
Revises: da42808081e3
|
||||
Create Date: 2025-02-24 17:35:33.072714
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8f43500ee275"
|
||||
down_revision = "da42808081e3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create a basic index on the lowercase message column for direct text matching
|
||||
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX idx_chat_message_message_lower
|
||||
ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the index
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
@@ -92,8 +92,7 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
|
||||
|
||||
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""This is a redis specific way to build a list of tasks in a queue and return them
|
||||
as a set.
|
||||
"""This is a redis specific way to build a list of tasks in a queue.
|
||||
|
||||
This helps us read the queue once and then efficiently look for missing tasks
|
||||
in the queue.
|
||||
|
||||
@@ -8,21 +8,16 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
@@ -114,7 +109,6 @@ def check_for_connector_deletion_task(
|
||||
) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -126,21 +120,6 @@ def check_for_connector_deletion_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Exception while validating connector deletion fences"
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -264,7 +243,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
redis_connector.delete.set_active()
|
||||
fence_payload = RedisConnectorDeletePayload(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
@@ -497,171 +475,3 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def validate_connector_deletion_fences(
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
# building lookup table can be expensive, so we won't bother
|
||||
# validating until the queue is small
|
||||
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
|
||||
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
|
||||
return
|
||||
|
||||
queued_upsert_tasks = celery_get_queued_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
|
||||
# validate all existing connector deletion jobs
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
validate_connector_deletion_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
queued_upsert_tasks,
|
||||
r,
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
|
||||
queued_tasks: the celery queue of lightweight permission sync tasks
|
||||
reserved_tasks: prefetched tasks for sync task generator
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.delete.fenced:
|
||||
return
|
||||
|
||||
# in the cloud, the payload format may have changed ...
|
||||
# it's a little sloppy, but just reset the fence for now if that happens
|
||||
# TODO: add intentional cleanup/abort logic
|
||||
try:
|
||||
payload = redis_connector.delete.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"validate_connector_deletion_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
# look up every task in the current taskset in the celery queue
|
||||
# every entry in the taskset should have an associated entry in the celery task queue
|
||||
# because we get the celery tasks first, the entries in our own permissions taskset
|
||||
# should be roughly a subset of the tasks in celery
|
||||
|
||||
# this check isn't very exact, but should be sufficient over a period of time
|
||||
# A single successful check over some number of attempts is sufficient.
|
||||
|
||||
# TODO: if the number of tasks in celery is much lower than than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.delete.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_connector_deletion_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
if tasks_scanned > 0 and tasks_not_in_celery == 0:
|
||||
redis_connector.delete.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.delete.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_connector_deletion_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -342,9 +342,6 @@ class OnyxRedisSignals:
|
||||
BLOCK_PRUNING = "signal:block_pruning"
|
||||
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
|
||||
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
|
||||
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
|
||||
"signal:block_validate_connector_deletion_fences"
|
||||
)
|
||||
|
||||
|
||||
class OnyxRedisConstants:
|
||||
|
||||
@@ -11,8 +11,6 @@ from atlassian import Confluence # type:ignore
|
||||
from pydantic import BaseModel
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -163,7 +161,7 @@ class OnyxConfluence(Confluence):
|
||||
)
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
|
||||
self, url_suffix: str, limit: int | None = None
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This will paginate through the top level query.
|
||||
@@ -238,41 +236,9 @@ class OnyxConfluence(Confluence):
|
||||
raise e
|
||||
|
||||
# yield the results individually
|
||||
results = cast(list[dict[str, Any]], next_response.get("results", []))
|
||||
yield from results
|
||||
yield from next_response.get("results", [])
|
||||
|
||||
old_url_suffix = url_suffix
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
|
||||
# make sure we don't update the start by more than the amount
|
||||
# of results we were able to retrieve. The Confluence API has a
|
||||
# weird behavior where if you pass in a limit that is too large for
|
||||
# the configured server, it will artificially limit the amount of
|
||||
# results returned BUT will not apply this to the start parameter.
|
||||
# This will cause us to miss results.
|
||||
if url_suffix and "start" in url_suffix:
|
||||
new_start = get_start_param_from_url(url_suffix)
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
if new_start - previous_start > len(results):
|
||||
logger.warning(
|
||||
f"Start was updated by more than the amount of results "
|
||||
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
|
||||
f"Previous Start: {previous_start}, Len Results: {len(results)}."
|
||||
)
|
||||
|
||||
# Update the url_suffix to use the adjusted start
|
||||
adjusted_start = previous_start + len(results)
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "start", str(adjusted_start)
|
||||
)
|
||||
|
||||
# some APIs don't properly paginate, so we need to manually update the `start` param
|
||||
if auto_paginate and len(results) > 0:
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
updated_start = previous_start + len(results)
|
||||
url_suffix = update_param_in_path(
|
||||
old_url_suffix, "start", str(updated_start)
|
||||
)
|
||||
url_suffix = next_response.get("_links", {}).get("next")
|
||||
|
||||
def paginated_cql_retrieval(
|
||||
self,
|
||||
@@ -332,9 +298,7 @@ class OnyxConfluence(Confluence):
|
||||
url = "rest/api/search/user"
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
url += f"?cql={cql}{expand_string}"
|
||||
# endpoint doesn't properly paginate, so we need to manually update the `start` param
|
||||
# thus the auto_paginate flag
|
||||
for user_result in self._paginate_url(url, limit, auto_paginate=True):
|
||||
for user_result in self._paginate_url(url, limit):
|
||||
# Example response:
|
||||
# {
|
||||
# 'user': {
|
||||
|
||||
@@ -2,10 +2,7 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
|
||||
@@ -13,13 +10,13 @@ from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
OnyxConfluence,
|
||||
)
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -27,7 +24,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: "OnyxConfluence", user_name: str
|
||||
confluence_client: OnyxConfluence, user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
@@ -50,7 +47,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
@@ -81,7 +78,7 @@ def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
@@ -194,7 +191,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
@@ -282,32 +279,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
def get_single_param_from_url(url: str, param: str) -> str | None:
|
||||
"""Get a parameter from a url"""
|
||||
parsed_url = urlparse(url)
|
||||
return parse_qs(parsed_url.query).get(param, [None])[0]
|
||||
|
||||
|
||||
def get_start_param_from_url(url: str) -> int:
|
||||
"""Get the start parameter from a url"""
|
||||
start_str = get_single_param_from_url(url, "start")
|
||||
if start_str is None:
|
||||
return 0
|
||||
return int(start_str)
|
||||
|
||||
|
||||
def update_param_in_path(path: str, param: str, value: str) -> str:
|
||||
"""Update a parameter in a path. Path should look something like:
|
||||
|
||||
/api/rest/users?start=0&limit=10
|
||||
"""
|
||||
parsed_url = urlparse(path)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params[param] = [value]
|
||||
return (
|
||||
path.split("?")[0]
|
||||
+ "?"
|
||||
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
)
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import literal
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import union_all
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def search_chat_sessions(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
query: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 10,
|
||||
include_deleted: bool = False,
|
||||
include_onyxbot_flows: bool = False,
|
||||
) -> Tuple[List[ChatSession], bool]:
|
||||
"""
|
||||
Search for chat sessions based on the provided query.
|
||||
If no query is provided, returns recent chat sessions.
|
||||
|
||||
Returns a tuple of (chat_sessions, has_more)
|
||||
"""
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# If no search query, we use standard SQLAlchemy pagination
|
||||
if not query or not query.strip():
|
||||
stmt = select(ChatSession)
|
||||
if user_id:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset(offset).limit(page_size + 1)
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
if has_more:
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return list(chat_sessions), has_more
|
||||
|
||||
words = query.lower().strip().split()
|
||||
|
||||
# Message mach subquery
|
||||
message_matches = []
|
||||
for word in words:
|
||||
word_like = f"%{word}%"
|
||||
message_match: Select = (
|
||||
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
|
||||
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
|
||||
.where(func.lower(ChatMessage.message).like(word_like))
|
||||
)
|
||||
|
||||
if user_id:
|
||||
message_match = message_match.where(ChatSession.user_id == user_id)
|
||||
|
||||
message_matches.append(message_match)
|
||||
|
||||
if message_matches:
|
||||
message_matches_query = union_all(*message_matches).alias("message_matches")
|
||||
else:
|
||||
return [], False
|
||||
|
||||
# Description matches
|
||||
description_match: Select = select(
|
||||
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
|
||||
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
|
||||
|
||||
if user_id:
|
||||
description_match = description_match.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
description_match = description_match.where(ChatSession.deleted.is_(False))
|
||||
|
||||
# Combine all match sources
|
||||
combined_matches = union_all(
|
||||
message_matches_query.select(), description_match
|
||||
).alias("combined_matches")
|
||||
|
||||
# Use CTE to group and get max rank
|
||||
session_ranks = (
|
||||
select(
|
||||
combined_matches.c.chat_session_id,
|
||||
func.max(combined_matches.c.search_rank).label("rank"),
|
||||
)
|
||||
.group_by(combined_matches.c.chat_session_id)
|
||||
.alias("session_ranks")
|
||||
)
|
||||
|
||||
# Get ranked sessions with pagination
|
||||
ranked_query = (
|
||||
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
|
||||
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
|
||||
.offset(offset)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
|
||||
result = ranked_query.all()
|
||||
|
||||
# Extract session IDs and ranks
|
||||
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
|
||||
session_ids = list(session_ids_with_ranks.keys())
|
||||
|
||||
if not session_ids:
|
||||
return [], False
|
||||
|
||||
# Now, let's query the actual ChatSession objects using the IDs
|
||||
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
# Full objects with eager loading
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
# Sort based on above ranking
|
||||
chat_sessions = sorted(
|
||||
chat_sessions,
|
||||
key=lambda session: (
|
||||
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
|
||||
session.time_created.timestamp() * -1, # Then by time (newest first)
|
||||
),
|
||||
)
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
if has_more:
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return chat_sessions, has_more
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
@@ -9,18 +8,15 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
@@ -35,12 +31,10 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
R = TypeVarTuple("R")
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
|
||||
) -> Select[tuple[*R]]:
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
return stmt
|
||||
@@ -104,52 +98,17 @@ def get_connector_credential_pairs_for_user(
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
if eager_load_user:
|
||||
assert (
|
||||
eager_load_credential
|
||||
), "eager_load_credential must be True if eager_load_user is True"
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
|
||||
|
||||
if eager_load_credential:
|
||||
load_opts = selectinload(ConnectorCredentialPair.credential)
|
||||
if eager_load_user:
|
||||
load_opts = load_opts.joinedload(Credential.user)
|
||||
stmt = stmt.options(load_opts)
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
return list(db_session.scalars(stmt).unique().all())
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_connector_credential_pairs_for_user_parallel(
|
||||
user: User | None,
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
db_session,
|
||||
user,
|
||||
get_editable,
|
||||
ids,
|
||||
eager_load_connector,
|
||||
eager_load_credential,
|
||||
eager_load_user,
|
||||
)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_connector_credential_pairs(
|
||||
@@ -192,16 +151,6 @@ def get_cc_pair_groups_for_ids(
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_cc_pair_groups_for_ids_parallel(
|
||||
cc_pair_ids: list[int],
|
||||
) -> list[UserGroup__ConnectorCredentialPair]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
|
||||
|
||||
|
||||
def get_connector_credential_pair_for_user(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
|
||||
@@ -24,7 +24,6 @@ from sqlalchemy.sql.expression import null
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
@@ -230,12 +229,12 @@ def get_document_connector_counts(
|
||||
|
||||
|
||||
def get_document_counts_for_cc_pairs(
|
||||
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
|
||||
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
|
||||
|
||||
# Prepare a list of (connector_id, credential_id) tuples
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
@@ -261,16 +260,6 @@ def get_document_counts_for_cc_pairs(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_document_counts_for_cc_pairs_parallel(
|
||||
cc_pairs: list[ConnectorCredentialPairIdentifier],
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
|
||||
|
||||
|
||||
def get_access_info_for_document(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
|
||||
@@ -218,7 +218,6 @@ class SqlEngine:
|
||||
final_engine_kwargs.update(engine_kwargs)
|
||||
|
||||
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
|
||||
# echo=True here for inspecting all emitted db queries
|
||||
engine = create_engine(connection_string, **final_engine_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
@@ -10,13 +9,9 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
@@ -373,33 +368,19 @@ def get_latest_index_attempts_by_status(
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
T = TypeVarTuple("T")
|
||||
|
||||
|
||||
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
|
||||
return stmt.where(
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_latest_index_attempts(
|
||||
secondary_index: bool,
|
||||
db_session: Session,
|
||||
eager_load_cc_pair: bool = False,
|
||||
only_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
ids_stmt = select(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.id).label("max_id"),
|
||||
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
|
||||
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == status)
|
||||
|
||||
if only_finished:
|
||||
ids_stmt = _add_only_finished_clause(ids_stmt)
|
||||
if secondary_index:
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
|
||||
else:
|
||||
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
|
||||
|
||||
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
ids_subquery = ids_stmt.subquery()
|
||||
@@ -414,53 +395,7 @@ def get_latest_index_attempts(
|
||||
.where(IndexAttempt.id == ids_subquery.c.max_id)
|
||||
)
|
||||
|
||||
if only_finished:
|
||||
stmt = _add_only_finished_clause(stmt)
|
||||
|
||||
if eager_load_cc_pair:
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector_credential_pair),
|
||||
joinedload(IndexAttempt.error_rows),
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).scalars().unique().all()
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_latest_index_attempts_parallel(
|
||||
secondary_index: bool,
|
||||
eager_load_cc_pair: bool = False,
|
||||
only_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
with get_session_context_manager() as db_session:
|
||||
return get_latest_index_attempts(
|
||||
secondary_index,
|
||||
db_session,
|
||||
eager_load_cc_pair,
|
||||
only_finished,
|
||||
)
|
||||
|
||||
|
||||
def get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool,
|
||||
only_finished: bool = True,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
)
|
||||
if only_finished:
|
||||
stmt = _add_only_finished_clause(stmt)
|
||||
|
||||
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
|
||||
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
stmt = stmt.limit(1)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def count_index_attempts_for_connector(
|
||||
@@ -518,12 +453,37 @@ def get_paginated_index_attempts_for_cc_pair_id(
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
stmt = stmt.options(
|
||||
contains_eager(IndexAttempt.connector_credential_pair),
|
||||
joinedload(IndexAttempt.error_rows),
|
||||
)
|
||||
|
||||
return list(db_session.execute(stmt).scalars().unique().all())
|
||||
return list(db_session.execute(stmt).scalars().all())
|
||||
|
||||
|
||||
def get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session: Session,
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool,
|
||||
only_finished: bool = True,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
)
|
||||
if only_finished:
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
if secondary_index:
|
||||
stmt = stmt.join(SearchSettings).where(
|
||||
SearchSettings.status == IndexModelStatus.FUTURE
|
||||
)
|
||||
else:
|
||||
stmt = stmt.join(SearchSettings).where(
|
||||
SearchSettings.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
stmt = stmt.limit(1)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_index_attempts_for_cc_pair(
|
||||
|
||||
@@ -17,12 +17,10 @@ from prometheus_client import Gauge
|
||||
from prometheus_client import start_http_server
|
||||
from redis.lock import Lock
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import POD_NAME
|
||||
@@ -251,12 +249,7 @@ class SlackbotHandler:
|
||||
- If yes, store them in self.tenant_ids and manage the socket connections.
|
||||
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
|
||||
"""
|
||||
|
||||
all_tenants = [
|
||||
tenant_id
|
||||
for tenant_id in get_all_tenant_ids()
|
||||
if tenant_id not in get_gated_tenants()
|
||||
]
|
||||
all_tenants = get_all_tenant_ids()
|
||||
|
||||
token: Token[str | None]
|
||||
|
||||
@@ -423,7 +416,6 @@ class SlackbotHandler:
|
||||
|
||||
try:
|
||||
bot_info = socket_client.web_client.auth_test()
|
||||
|
||||
if bot_info["ok"]:
|
||||
bot_user_id = bot_info["user_id"]
|
||||
user_info = socket_client.web_client.users_info(user=bot_user_id)
|
||||
@@ -434,23 +426,9 @@ class SlackbotHandler:
|
||||
logger.info(
|
||||
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
|
||||
)
|
||||
except SlackApiError as e:
|
||||
# Only error out if we get a not_authed error
|
||||
if "not_authed" in str(e):
|
||||
self.tenant_ids.add(tenant_id)
|
||||
logger.error(
|
||||
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
|
||||
"Error: {e}"
|
||||
)
|
||||
return
|
||||
# Log other Slack API errors but continue
|
||||
logger.error(
|
||||
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log other exceptions but continue
|
||||
logger.error(
|
||||
f"Error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
logger.warning(
|
||||
f"Could not fetch bot name: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
)
|
||||
|
||||
# Append the event handler
|
||||
|
||||
@@ -33,12 +33,6 @@ class RedisConnectorDelete:
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
# so we need a signal with a TTL to bridge gaps in our checks
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
@@ -47,8 +41,6 @@ class RedisConnectorDelete:
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
@@ -85,20 +77,6 @@ class RedisConnectorDelete:
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def set_active(self) -> None:
|
||||
"""This sets a signal to keep the permissioning flow from getting cleaned up within
|
||||
the expiration time.
|
||||
|
||||
The slack in timing is needed to avoid race conditions where simply checking
|
||||
the celery queue and task status could result in race conditions."""
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _generate_task_id(self) -> str:
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
@@ -163,7 +141,6 @@ class RedisConnectorDelete:
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
self.redis.delete(self.active_key)
|
||||
self.redis.delete(self.taskset_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@@ -176,9 +153,6 @@ class RedisConnectorDelete:
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorDelete.ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
@@ -93,7 +93,10 @@ class RedisConnectorIndex:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorIndexPayload | None:
|
||||
@@ -103,7 +106,9 @@ class RedisConnectorIndex:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
|
||||
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
|
||||
|
||||
return payload
|
||||
|
||||
def set_fence(
|
||||
self,
|
||||
@@ -118,7 +123,10 @@ class RedisConnectorIndex:
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def terminating(self, celery_task_id: str) -> bool:
|
||||
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
|
||||
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_terminate(self, celery_task_id: str) -> None:
|
||||
"""This sets a signal. It does not block!"""
|
||||
@@ -138,7 +146,10 @@ class RedisConnectorIndex:
|
||||
|
||||
def watchdog_signaled(self) -> bool:
|
||||
"""Check the state of the watchdog."""
|
||||
return bool(self.redis.exists(self.watchdog_key))
|
||||
if self.redis.exists(self.watchdog_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_active(self) -> None:
|
||||
"""This sets a signal to keep the indexing flow from getting cleaned up within
|
||||
@@ -149,7 +160,10 @@ class RedisConnectorIndex:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_connector_active(self) -> None:
|
||||
"""This sets a signal to keep the indexing flow from getting cleaned up within
|
||||
@@ -166,7 +180,10 @@ class RedisConnectorIndex:
|
||||
return False
|
||||
|
||||
def generator_locked(self) -> bool:
|
||||
return bool(self.redis.exists(self.generator_lock_key))
|
||||
if self.redis.exists(self.generator_lock_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_generator_complete(self, payload: int | None) -> None:
|
||||
if not payload:
|
||||
|
||||
@@ -123,15 +123,15 @@ def get_cc_pair_full_info(
|
||||
)
|
||||
is_editable_for_current_user = editable_cc_pair is not None
|
||||
|
||||
cc_pair_identifier = ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
document_count_info_list = list(
|
||||
get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pairs=[
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
],
|
||||
cc_pair_identifiers=[cc_pair_identifier],
|
||||
)
|
||||
)
|
||||
documents_indexed = (
|
||||
|
||||
@@ -72,31 +72,25 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector import update_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids_parallel
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
)
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import delete_service_account_credentials
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs_parallel
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.index_attempt import get_latest_index_attempts
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_by_status
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_parallel
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.file_processing.extract_file_text import convert_docx_to_txt
|
||||
@@ -125,8 +119,8 @@ from onyx.server.documents.models import RunConnectorRequest
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -584,8 +578,6 @@ def get_connector_status(
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
|
||||
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
|
||||
@@ -640,35 +632,23 @@ def get_connector_indexing_status(
|
||||
# Additional checks are done to make sure the connector and credential still exist.
|
||||
# TODO: make this one query ... possibly eager load or wrap in a read transaction
|
||||
# to avoid the complexity of trying to error check throughout the function
|
||||
|
||||
# see https://stackoverflow.com/questions/75758327/
|
||||
# sqlalchemy-method-connection-for-bind-is-already-in-progress
|
||||
# for why we can't pass in the current db_session to these functions
|
||||
(
|
||||
cc_pairs,
|
||||
latest_index_attempts,
|
||||
latest_finished_index_attempts,
|
||||
) = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
# Gets the connector/credential pairs for the user
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, get_editable, None, True, True, True),
|
||||
),
|
||||
(
|
||||
# Gets the most recent index attempt for each connector/credential pair
|
||||
get_latest_index_attempts_parallel,
|
||||
(secondary_index, True, False),
|
||||
),
|
||||
(
|
||||
# Gets the most recent FINISHED index attempt for each connector/credential pair
|
||||
get_latest_index_attempts_parallel,
|
||||
(secondary_index, True, True),
|
||||
),
|
||||
]
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
|
||||
cc_pair_identifiers = [
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
]
|
||||
|
||||
latest_index_attempts = get_latest_index_attempts(
|
||||
secondary_index=secondary_index,
|
||||
db_session=db_session,
|
||||
)
|
||||
cc_pairs = cast(list[ConnectorCredentialPair], cc_pairs)
|
||||
latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts)
|
||||
|
||||
cc_pair_to_latest_index_attempt = {
|
||||
(
|
||||
@@ -678,60 +658,31 @@ def get_connector_indexing_status(
|
||||
for index_attempt in latest_index_attempts
|
||||
}
|
||||
|
||||
cc_pair_to_latest_finished_index_attempt = {
|
||||
(
|
||||
index_attempt.connector_credential_pair.connector_id,
|
||||
index_attempt.connector_credential_pair.credential_id,
|
||||
): index_attempt
|
||||
for index_attempt in latest_finished_index_attempts
|
||||
}
|
||||
|
||||
document_count_info, group_cc_pair_relationships = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
get_document_counts_for_cc_pairs_parallel,
|
||||
(
|
||||
[
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
],
|
||||
),
|
||||
),
|
||||
(
|
||||
get_cc_pair_groups_for_ids_parallel,
|
||||
([cc_pair.id for cc_pair in cc_pairs],),
|
||||
),
|
||||
]
|
||||
document_count_info = get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=cc_pair_identifiers,
|
||||
)
|
||||
document_count_info = cast(list[tuple[int, int, int]], document_count_info)
|
||||
group_cc_pair_relationships = cast(
|
||||
list[UserGroup__ConnectorCredentialPair], group_cc_pair_relationships
|
||||
)
|
||||
|
||||
cc_pair_to_document_cnt = {
|
||||
(connector_id, credential_id): cnt
|
||||
for connector_id, credential_id, cnt in document_count_info
|
||||
}
|
||||
|
||||
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
|
||||
db_session=db_session,
|
||||
cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs],
|
||||
)
|
||||
group_cc_pair_relationships_dict: dict[int, list[int]] = {}
|
||||
for relationship in group_cc_pair_relationships:
|
||||
group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append(
|
||||
relationship.user_group_id
|
||||
)
|
||||
|
||||
connector_to_cc_pair_ids: dict[int, list[int]] = {}
|
||||
for cc_pair in cc_pairs:
|
||||
connector_to_cc_pair_ids.setdefault(cc_pair.connector_id, []).append(cc_pair.id)
|
||||
search_settings: SearchSettings | None = None
|
||||
if not secondary_index:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
else:
|
||||
search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
get_search_settings = (
|
||||
get_secondary_search_settings
|
||||
if secondary_index
|
||||
else get_current_search_settings
|
||||
)
|
||||
search_settings = get_search_settings(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
# TODO remove this to enable ingestion API
|
||||
if cc_pair.name == "DefaultCCPair":
|
||||
@@ -754,8 +705,11 @@ def get_connector_indexing_status(
|
||||
(connector.id, credential.id)
|
||||
)
|
||||
|
||||
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
|
||||
(connector.id, credential.id)
|
||||
latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session=db_session,
|
||||
connector_credential_pair_id=cc_pair.id,
|
||||
secondary_index=secondary_index,
|
||||
only_finished=True,
|
||||
)
|
||||
|
||||
indexing_statuses.append(
|
||||
@@ -764,9 +718,7 @@ def get_connector_indexing_status(
|
||||
name=cc_pair.name,
|
||||
in_progress=in_progress,
|
||||
cc_pair_status=cc_pair.status,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
connector, connector_to_cc_pair_ids.get(connector.id, [])
|
||||
),
|
||||
connector=ConnectorSnapshot.from_connector_db_model(connector),
|
||||
credential=CredentialSnapshot.from_credential_db_model(credential),
|
||||
access_type=cc_pair.access_type,
|
||||
owner=credential.user.email if credential.user else "",
|
||||
|
||||
@@ -83,9 +83,7 @@ class ConnectorSnapshot(ConnectorBase):
|
||||
source: DocumentSource
|
||||
|
||||
@classmethod
|
||||
def from_connector_db_model(
|
||||
cls, connector: Connector, credential_ids: list[int] | None = None
|
||||
) -> "ConnectorSnapshot":
|
||||
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
|
||||
return ConnectorSnapshot(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
@@ -94,10 +92,9 @@ class ConnectorSnapshot(ConnectorBase):
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
refresh_freq=connector.refresh_freq,
|
||||
prune_freq=connector.prune_freq,
|
||||
credential_ids=(
|
||||
credential_ids
|
||||
or [association.credential.id for association in connector.credentials]
|
||||
),
|
||||
credential_ids=[
|
||||
association.credential.id for association in connector.credentials
|
||||
],
|
||||
indexing_start=connector.indexing_start,
|
||||
time_created=connector.time_created,
|
||||
time_updated=connector.time_updated,
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import UploadFile
|
||||
@@ -47,7 +44,6 @@ from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
@@ -69,13 +65,10 @@ from onyx.secondary_llm_flows.chat_session_naming import (
|
||||
from onyx.server.query_and_chat.models import ChatFeedbackRequest
|
||||
from onyx.server.query_and_chat.models import ChatMessageIdentifier
|
||||
from onyx.server.query_and_chat.models import ChatRenameRequest
|
||||
from onyx.server.query_and_chat.models import ChatSearchResponse
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import ChatSessionDetailResponse
|
||||
from onyx.server.query_and_chat.models import ChatSessionDetails
|
||||
from onyx.server.query_and_chat.models import ChatSessionGroup
|
||||
from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
from onyx.server.query_and_chat.models import ChatSessionSummary
|
||||
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatSessionID
|
||||
@@ -801,84 +794,3 @@ def fetch_chat_file(
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
|
||||
return StreamingResponse(file_io, media_type=media_type)
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_chats(
|
||||
query: str | None = Query(None),
|
||||
page: int = Query(1),
|
||||
page_size: int = Query(10),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSearchResponse:
|
||||
"""
|
||||
Search for chat sessions based on the provided query.
|
||||
If no query is provided, returns recent chat sessions.
|
||||
"""
|
||||
|
||||
# Use the enhanced database function for chat search
|
||||
chat_sessions, has_more = search_chat_sessions(
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
query=query,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
include_deleted=False,
|
||||
include_onyxbot_flows=False,
|
||||
)
|
||||
|
||||
# Group chat sessions by time period
|
||||
today = datetime.datetime.now().date()
|
||||
yesterday = today - timedelta(days=1)
|
||||
this_week = today - timedelta(days=7)
|
||||
this_month = today - timedelta(days=30)
|
||||
|
||||
today_chats: list[ChatSessionSummary] = []
|
||||
yesterday_chats: list[ChatSessionSummary] = []
|
||||
this_week_chats: list[ChatSessionSummary] = []
|
||||
this_month_chats: list[ChatSessionSummary] = []
|
||||
older_chats: list[ChatSessionSummary] = []
|
||||
|
||||
for session in chat_sessions:
|
||||
session_date = session.time_created.date()
|
||||
|
||||
chat_summary = ChatSessionSummary(
|
||||
id=session.id,
|
||||
name=session.description,
|
||||
persona_id=session.persona_id,
|
||||
time_created=session.time_created,
|
||||
shared_status=session.shared_status,
|
||||
folder_id=session.folder_id,
|
||||
current_alternate_model=session.current_alternate_model,
|
||||
current_temperature_override=session.temperature_override,
|
||||
)
|
||||
|
||||
if session_date == today:
|
||||
today_chats.append(chat_summary)
|
||||
elif session_date == yesterday:
|
||||
yesterday_chats.append(chat_summary)
|
||||
elif session_date > this_week:
|
||||
this_week_chats.append(chat_summary)
|
||||
elif session_date > this_month:
|
||||
this_month_chats.append(chat_summary)
|
||||
else:
|
||||
older_chats.append(chat_summary)
|
||||
|
||||
# Create groups
|
||||
groups = []
|
||||
if today_chats:
|
||||
groups.append(ChatSessionGroup(title="Today", chats=today_chats))
|
||||
if yesterday_chats:
|
||||
groups.append(ChatSessionGroup(title="Yesterday", chats=yesterday_chats))
|
||||
if this_week_chats:
|
||||
groups.append(ChatSessionGroup(title="This Week", chats=this_week_chats))
|
||||
if this_month_chats:
|
||||
groups.append(ChatSessionGroup(title="This Month", chats=this_month_chats))
|
||||
if older_chats:
|
||||
groups.append(ChatSessionGroup(title="Older", chats=older_chats))
|
||||
|
||||
return ChatSearchResponse(
|
||||
groups=groups,
|
||||
has_more=has_more,
|
||||
next_page=page + 1 if has_more else None,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,6 @@ from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
@@ -283,35 +282,3 @@ class AdminSearchRequest(BaseModel):
|
||||
|
||||
class AdminSearchResponse(BaseModel):
|
||||
documents: list[SearchDoc]
|
||||
|
||||
|
||||
class ChatSessionSummary(BaseModel):
|
||||
id: UUID
|
||||
name: str | None = None
|
||||
persona_id: int | None = None
|
||||
time_created: datetime
|
||||
shared_status: ChatSessionSharedStatus
|
||||
folder_id: int | None = None
|
||||
current_alternate_model: str | None = None
|
||||
current_temperature_override: float | None = None
|
||||
|
||||
|
||||
class ChatSessionGroup(BaseModel):
|
||||
title: str
|
||||
chats: list[ChatSessionSummary]
|
||||
|
||||
|
||||
class ChatSearchResponse(BaseModel):
|
||||
groups: list[ChatSessionGroup]
|
||||
has_more: bool
|
||||
next_page: int | None = None
|
||||
|
||||
|
||||
class ChatSearchRequest(BaseModel):
|
||||
query: str | None = None
|
||||
page: int = 1
|
||||
page_size: int = 10
|
||||
|
||||
|
||||
class CreateChatResponse(BaseModel):
|
||||
chat_session_id: str
|
||||
|
||||
@@ -17,6 +17,6 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
|
||||
|
||||
def get_current_tenant_id() -> str:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id is None:
|
||||
if tenant_id is None and MULTI_TENANT:
|
||||
raise RuntimeError("Tenant ID is not set. This should never happen.")
|
||||
return tenant_id
|
||||
|
||||
@@ -68,28 +68,6 @@ const nextConfig = {
|
||||
},
|
||||
];
|
||||
},
|
||||
async rewrites() {
|
||||
return [
|
||||
{
|
||||
source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
|
||||
destination: `${
|
||||
process.env.INTERNAL_URL || "http://localhost:8080"
|
||||
}/docs/:path*`,
|
||||
},
|
||||
{
|
||||
source: "/api/docs", // if you also need the exact /api/docs
|
||||
destination: `${
|
||||
process.env.INTERNAL_URL || "http://localhost:8080"
|
||||
}/docs`,
|
||||
},
|
||||
{
|
||||
source: "/openapi.json",
|
||||
destination: `${
|
||||
process.env.INTERNAL_URL || "http://localhost:8080"
|
||||
}/openapi.json`,
|
||||
},
|
||||
];
|
||||
},
|
||||
};
|
||||
|
||||
// Sentry configuration for error monitoring:
|
||||
|
||||
@@ -142,7 +142,6 @@ import {
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
import { MessageChannel } from "node:worker_threads";
|
||||
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@@ -871,7 +870,6 @@ export function ChatPage({
|
||||
}, [liveAssistant]);
|
||||
|
||||
const filterManager = useFilters();
|
||||
const [isChatSearchModalOpen, setIsChatSearchModalOpen] = useState(false);
|
||||
|
||||
const [currentFeedback, setCurrentFeedback] = useState<
|
||||
[FeedbackType, number] | null
|
||||
@@ -2331,11 +2329,6 @@ export function ChatPage({
|
||||
/>
|
||||
)}
|
||||
|
||||
<ChatSearchModal
|
||||
open={isChatSearchModalOpen}
|
||||
onCloseModal={() => setIsChatSearchModalOpen(false)}
|
||||
/>
|
||||
|
||||
{retrievalEnabled && documentSidebarVisible && settings?.isMobile && (
|
||||
<div className="md:hidden">
|
||||
<Modal
|
||||
@@ -2443,9 +2436,6 @@ export function ChatPage({
|
||||
>
|
||||
<div className="w-full relative">
|
||||
<HistorySidebar
|
||||
toggleChatSessionSearchModal={() =>
|
||||
setIsChatSearchModalOpen((open) => !open)
|
||||
}
|
||||
liveAssistant={liveAssistant}
|
||||
setShowAssistantsModal={setShowAssistantsModal}
|
||||
explicitlyUntoggle={explicitlyUntoggle}
|
||||
@@ -2462,7 +2452,6 @@ export function ChatPage({
|
||||
showDeleteAllModal={() => setShowDeleteAllModal(true)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div
|
||||
className={`
|
||||
flex-none
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
import React from "react";
|
||||
import { ChatSearchItem } from "./ChatSearchItem";
|
||||
import { ChatSessionSummary } from "./interfaces";
|
||||
|
||||
interface ChatSearchGroupProps {
|
||||
title: string;
|
||||
chats: ChatSessionSummary[];
|
||||
onSelectChat: (id: string) => void;
|
||||
}
|
||||
|
||||
export function ChatSearchGroup({
|
||||
title,
|
||||
chats,
|
||||
onSelectChat,
|
||||
}: ChatSearchGroupProps) {
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-gray-800/90 py-2 px-4 px-4">
|
||||
<div className="text-xs font-medium leading-4 text-gray-600 dark:text-gray-400">
|
||||
{title}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ol>
|
||||
{chats.map((chat) => (
|
||||
<ChatSearchItem key={chat.id} chat={chat} onSelect={onSelectChat} />
|
||||
))}
|
||||
</ol>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
import React from "react";
|
||||
import { MessageSquare } from "lucide-react";
|
||||
import { ChatSessionSummary } from "./interfaces";
|
||||
|
||||
interface ChatSearchItemProps {
|
||||
chat: ChatSessionSummary;
|
||||
onSelect: (id: string) => void;
|
||||
}
|
||||
|
||||
export function ChatSearchItem({ chat, onSelect }: ChatSearchItemProps) {
|
||||
return (
|
||||
<li>
|
||||
<div className="cursor-pointer" onClick={() => onSelect(chat.id)}>
|
||||
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-800">
|
||||
<div className="flex items-center">
|
||||
<MessageSquare className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
|
||||
<div className="relative grow overflow-hidden whitespace-nowrap pl-4">
|
||||
<div className="text-sm dark:text-neutral-200">
|
||||
{chat.name || "Untitled Chat"}
|
||||
</div>
|
||||
</div>
|
||||
<div className="opacity-0 group-hover:opacity-100 transition-opacity text-xs text-neutral-500 dark:text-neutral-400">
|
||||
{new Date(chat.time_created).toLocaleDateString()}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</li>
|
||||
);
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
import React, { useRef } from "react";
|
||||
import { Dialog, DialogContent } from "@/components/ui/dialog";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { ChatSearchGroup } from "./ChatSearchGroup";
|
||||
import { NewChatButton } from "./NewChatButton";
|
||||
import { useChatSearch } from "./hooks/useChatSearch";
|
||||
import { LoadingSpinner } from "./LoadingSpinner";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { SearchInput } from "./components/SearchInput";
|
||||
import { ChatSearchSkeletonList } from "./components/ChatSearchSkeleton";
|
||||
import { useIntersectionObserver } from "./hooks/useIntersectionObserver";
|
||||
|
||||
interface ChatSearchModalProps {
|
||||
open: boolean;
|
||||
onCloseModal: () => void;
|
||||
}
|
||||
|
||||
export function ChatSearchModal({ open, onCloseModal }: ChatSearchModalProps) {
|
||||
const {
|
||||
searchQuery,
|
||||
setSearchQuery,
|
||||
chatGroups,
|
||||
isLoading,
|
||||
isSearching,
|
||||
hasMore,
|
||||
fetchMoreChats,
|
||||
} = useChatSearch();
|
||||
|
||||
const onClose = () => {
|
||||
setSearchQuery("");
|
||||
onCloseModal();
|
||||
};
|
||||
|
||||
const router = useRouter();
|
||||
const scrollAreaRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const { targetRef } = useIntersectionObserver({
|
||||
root: scrollAreaRef.current,
|
||||
onIntersect: fetchMoreChats,
|
||||
enabled: open && hasMore && !isLoading,
|
||||
});
|
||||
|
||||
const handleChatSelect = (chatId: string) => {
|
||||
router.push(`/chat?chatId=${chatId}`);
|
||||
onClose();
|
||||
};
|
||||
|
||||
const handleNewChat = async () => {
|
||||
try {
|
||||
onClose();
|
||||
router.push(`/chat`);
|
||||
} catch (error) {
|
||||
console.error("Error creating new chat:", error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={(open) => !open && onClose()}>
|
||||
<DialogContent
|
||||
hideCloseIcon
|
||||
className="!rounded-xl overflow-hidden p-0 w-full max-w-2xl"
|
||||
backgroundColor="bg-neutral-950/20 shadow-xl"
|
||||
>
|
||||
<div className="w-full flex flex-col bg-white dark:bg-neutral-800 h-[80vh] max-h-[600px]">
|
||||
<div className="sticky top-0 z-20 px-6 py-3 w-full flex items-center justify-between bg-white dark:bg-neutral-800 border-b border-neutral-200 dark:border-neutral-700">
|
||||
<SearchInput
|
||||
searchQuery={searchQuery}
|
||||
setSearchQuery={setSearchQuery}
|
||||
isSearching={isSearching}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<ScrollArea
|
||||
className="flex-grow bg-white relative dark:bg-neutral-800"
|
||||
ref={scrollAreaRef}
|
||||
type="auto"
|
||||
>
|
||||
<div className="px-4 py-2">
|
||||
<NewChatButton onClick={handleNewChat} />
|
||||
|
||||
{isSearching ? (
|
||||
<ChatSearchSkeletonList />
|
||||
) : isLoading && chatGroups.length === 0 ? (
|
||||
<div className="py-8">
|
||||
<LoadingSpinner size="large" className="mx-auto" />
|
||||
</div>
|
||||
) : chatGroups.length > 0 ? (
|
||||
<>
|
||||
{chatGroups.map((group, groupIndex) => (
|
||||
<ChatSearchGroup
|
||||
key={groupIndex}
|
||||
title={group.title}
|
||||
chats={group.chats}
|
||||
onSelectChat={handleChatSelect}
|
||||
/>
|
||||
))}
|
||||
|
||||
<div ref={targetRef} className="py-4">
|
||||
{isLoading && hasMore && (
|
||||
<LoadingSpinner className="mx-auto" />
|
||||
)}
|
||||
{!hasMore && chatGroups.length > 0 && (
|
||||
<div className="text-center text-xs text-neutral-500 dark:text-neutral-400 py-2">
|
||||
No more chats to load
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
!isLoading && (
|
||||
<div className="px-4 py-3 text-sm text-neutral-500 dark:text-neutral-400">
|
||||
No chats found
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
import React from "react";
|
||||
|
||||
interface LoadingSpinnerProps {
|
||||
size?: "small" | "medium" | "large";
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function LoadingSpinner({
|
||||
size = "medium",
|
||||
className = "",
|
||||
}: LoadingSpinnerProps) {
|
||||
const sizeClasses = {
|
||||
small: "h-4 w-4 border-2",
|
||||
medium: "h-6 w-6 border-2",
|
||||
large: "h-8 w-8 border-3",
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={`flex justify-center items-center ${className}`}>
|
||||
<div
|
||||
className={`${sizeClasses[size]} animate-spin rounded-full border-solid border-gray-300 border-t-gray-600 dark:border-gray-600 dark:border-t-gray-300`}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
import React from "react";
|
||||
import { PlusCircle } from "lucide-react";
|
||||
import { NewChatIcon } from "@/components/icons/icons";
|
||||
|
||||
interface NewChatButtonProps {
|
||||
onClick: () => void;
|
||||
}
|
||||
|
||||
export function NewChatButton({ onClick }: NewChatButtonProps) {
|
||||
return (
|
||||
<div className="mb-2">
|
||||
<div className="cursor-pointer" onClick={onClick}>
|
||||
<div className="group relative flex items-center rounded-lg px-4 py-3 hover:bg-neutral-100 dark:bg-neutral-800 dark:hover:bg-neutral-700">
|
||||
<NewChatIcon className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
|
||||
<div className="relative grow overflow-hidden whitespace-nowrap pl-4">
|
||||
<div className="text-sm dark:text-neutral-200">New Chat</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
import React from "react";
|
||||
|
||||
export function ChatSearchItemSkeleton() {
|
||||
return (
|
||||
<div className="animate-pulse px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-700 rounded-lg">
|
||||
<div className="flex items-center">
|
||||
<div className="h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-700"></div>
|
||||
<div className="ml-4 flex-1">
|
||||
<div className="h-2 my-1 w-3/4 bg-neutral-200 dark:bg-neutral-700 rounded"></div>
|
||||
<div className="mt-2 h-3 w-1/2 bg-neutral-200 dark:bg-neutral-700 rounded"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function ChatSearchSkeletonList() {
|
||||
return (
|
||||
<div>
|
||||
{[...Array(5)].map((_, index) => (
|
||||
<ChatSearchItemSkeleton key={index} />
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
import React from "react";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { XIcon } from "lucide-react";
|
||||
import { LoadingSpinner } from "../LoadingSpinner";
|
||||
|
||||
interface SearchInputProps {
|
||||
searchQuery: string;
|
||||
setSearchQuery: (query: string) => void;
|
||||
isSearching: boolean;
|
||||
}
|
||||
|
||||
export function SearchInput({
|
||||
searchQuery,
|
||||
setSearchQuery,
|
||||
isSearching,
|
||||
}: SearchInputProps) {
|
||||
return (
|
||||
<div className="relative w-full">
|
||||
<div className="flex items-center">
|
||||
<Input
|
||||
removeFocusRing
|
||||
className="w-full !focus-visible:ring-offset-0 !focus-visible:ring-none !focus-visible:ring-0 hover:focus-none border-none bg-transparent placeholder:text-neutral-400 focus:border-transparent focus:outline-none focus:ring-0 dark:placeholder:text-neutral-500 dark:text-neutral-200"
|
||||
placeholder="Search chats..."
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
/>
|
||||
{searchQuery &&
|
||||
(isSearching ? (
|
||||
<div className="absolute right-2 top-1/2 -translate-y-1/2">
|
||||
<LoadingSpinner size="small" />
|
||||
</div>
|
||||
) : (
|
||||
<XIcon
|
||||
size={16}
|
||||
className="absolute right-2 top-1/2 -translate-y-1/2 cursor-pointer"
|
||||
onClick={() => setSearchQuery("")}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,255 +0,0 @@
|
||||
import { useState, useEffect, useCallback, useRef } from "react";
|
||||
import { fetchChatSessions } from "../utils";
|
||||
import { ChatSessionGroup, ChatSessionSummary } from "../interfaces";
|
||||
|
||||
interface UseChatSearchOptions {
|
||||
pageSize?: number;
|
||||
}
|
||||
|
||||
interface UseChatSearchResult {
|
||||
searchQuery: string;
|
||||
setSearchQuery: (query: string) => void;
|
||||
chatGroups: ChatSessionGroup[];
|
||||
isLoading: boolean;
|
||||
isSearching: boolean;
|
||||
hasMore: boolean;
|
||||
fetchMoreChats: () => Promise<void>;
|
||||
refreshChats: () => Promise<void>;
|
||||
}
|
||||
|
||||
export function useChatSearch(
|
||||
options: UseChatSearchOptions = {}
|
||||
): UseChatSearchResult {
|
||||
const { pageSize = 10 } = options;
|
||||
const [searchQuery, setSearchQueryInternal] = useState("");
|
||||
const [chatGroups, setChatGroups] = useState<ChatSessionGroup[]>([]);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [isSearching, setIsSearching] = useState(false);
|
||||
const [hasMore, setHasMore] = useState(true);
|
||||
const [debouncedIsSearching, setDebouncedIsSearching] = useState(false);
|
||||
|
||||
const [page, setPage] = useState(1);
|
||||
const searchTimeoutRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const currentAbortController = useRef<AbortController | null>(null);
|
||||
const activeSearchIdRef = useRef<number>(0); // Add a unique ID for each search
|
||||
const PAGE_SIZE = pageSize;
|
||||
|
||||
useEffect(() => {
|
||||
// Only set a timeout if we're not already in the desired state
|
||||
if (!isSearching) {
|
||||
const timeout = setTimeout(() => {
|
||||
setDebouncedIsSearching(isSearching);
|
||||
}, 300);
|
||||
|
||||
// Keep track of the timeout reference to clear it on cleanup
|
||||
const timeoutRef = timeout;
|
||||
|
||||
return () => clearTimeout(timeoutRef);
|
||||
} else {
|
||||
setDebouncedIsSearching(isSearching);
|
||||
}
|
||||
}, [isSearching, debouncedIsSearching]);
|
||||
|
||||
// Helper function to merge groups properly
|
||||
const mergeGroups = useCallback(
|
||||
(
|
||||
existingGroups: ChatSessionGroup[],
|
||||
newGroups: ChatSessionGroup[]
|
||||
): ChatSessionGroup[] => {
|
||||
const mergedGroups: Record<string, ChatSessionSummary[]> = {};
|
||||
|
||||
// Initialize with existing groups
|
||||
existingGroups.forEach((group) => {
|
||||
mergedGroups[group.title] = [
|
||||
...(mergedGroups[group.title] || []),
|
||||
...group.chats,
|
||||
];
|
||||
});
|
||||
|
||||
// Merge in new groups
|
||||
newGroups.forEach((group) => {
|
||||
mergedGroups[group.title] = [
|
||||
...(mergedGroups[group.title] || []),
|
||||
...group.chats,
|
||||
];
|
||||
});
|
||||
|
||||
// Convert back to array format
|
||||
return Object.entries(mergedGroups)
|
||||
.map(([title, chats]) => ({ title, chats }))
|
||||
.sort((a, b) => {
|
||||
// Custom sort order for time periods
|
||||
const order = [
|
||||
"Today",
|
||||
"Yesterday",
|
||||
"This Week",
|
||||
"This Month",
|
||||
"Older",
|
||||
];
|
||||
return order.indexOf(a.title) - order.indexOf(b.title);
|
||||
});
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const fetchInitialChats = useCallback(
|
||||
async (query: string, searchId: number, signal?: AbortSignal) => {
|
||||
try {
|
||||
setIsLoading(true);
|
||||
setPage(1);
|
||||
|
||||
const response = await fetchChatSessions({
|
||||
query,
|
||||
page: 1,
|
||||
page_size: PAGE_SIZE,
|
||||
signal,
|
||||
});
|
||||
|
||||
// Only update state if this is still the active search
|
||||
if (activeSearchIdRef.current === searchId && !signal?.aborted) {
|
||||
setChatGroups(response.groups);
|
||||
setHasMore(response.has_more);
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (
|
||||
error?.name !== "AbortError" &&
|
||||
activeSearchIdRef.current === searchId
|
||||
) {
|
||||
console.error("Error fetching chats:", error);
|
||||
}
|
||||
} finally {
|
||||
// Only update loading state if this is still the active search
|
||||
if (activeSearchIdRef.current === searchId) {
|
||||
setIsLoading(false);
|
||||
setIsSearching(false);
|
||||
}
|
||||
}
|
||||
},
|
||||
[PAGE_SIZE]
|
||||
);
|
||||
|
||||
const fetchMoreChats = useCallback(async () => {
|
||||
if (isLoading || !hasMore) return;
|
||||
|
||||
setIsLoading(true);
|
||||
|
||||
if (currentAbortController.current) {
|
||||
currentAbortController.current.abort();
|
||||
}
|
||||
|
||||
const newSearchId = activeSearchIdRef.current + 1;
|
||||
activeSearchIdRef.current = newSearchId;
|
||||
|
||||
const controller = new AbortController();
|
||||
currentAbortController.current = controller;
|
||||
const localSignal = controller.signal;
|
||||
|
||||
try {
|
||||
const nextPage = page + 1;
|
||||
const response = await fetchChatSessions({
|
||||
query: searchQuery,
|
||||
page: nextPage,
|
||||
page_size: PAGE_SIZE,
|
||||
signal: localSignal,
|
||||
});
|
||||
|
||||
if (activeSearchIdRef.current === newSearchId && !localSignal.aborted) {
|
||||
// Use mergeGroups instead of just concatenating
|
||||
setChatGroups((prevGroups) => mergeGroups(prevGroups, response.groups));
|
||||
setHasMore(response.has_more);
|
||||
setPage(nextPage);
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (
|
||||
error?.name !== "AbortError" &&
|
||||
activeSearchIdRef.current === newSearchId
|
||||
) {
|
||||
console.error("Error fetching more chats:", error);
|
||||
}
|
||||
} finally {
|
||||
if (activeSearchIdRef.current === newSearchId) {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
}, [isLoading, hasMore, page, searchQuery, PAGE_SIZE, mergeGroups]);
|
||||
|
||||
const setSearchQuery = useCallback(
|
||||
(query: string) => {
|
||||
setSearchQueryInternal(query);
|
||||
|
||||
// Clear any pending timeouts
|
||||
if (searchTimeoutRef.current) {
|
||||
clearTimeout(searchTimeoutRef.current);
|
||||
searchTimeoutRef.current = null;
|
||||
}
|
||||
|
||||
// Abort any in-flight requests
|
||||
if (currentAbortController.current) {
|
||||
currentAbortController.current.abort();
|
||||
currentAbortController.current = null;
|
||||
}
|
||||
|
||||
// Create a new search ID
|
||||
const newSearchId = activeSearchIdRef.current + 1;
|
||||
activeSearchIdRef.current = newSearchId;
|
||||
|
||||
if (query.trim()) {
|
||||
setIsSearching(true);
|
||||
|
||||
const controller = new AbortController();
|
||||
currentAbortController.current = controller;
|
||||
|
||||
searchTimeoutRef.current = setTimeout(() => {
|
||||
fetchInitialChats(query, newSearchId, controller.signal);
|
||||
}, 500);
|
||||
} else {
|
||||
// For empty queries, clear search state immediately
|
||||
setIsSearching(false);
|
||||
// Optionally fetch initial unfiltered results
|
||||
fetchInitialChats("", newSearchId);
|
||||
}
|
||||
},
|
||||
[fetchInitialChats]
|
||||
);
|
||||
|
||||
// Initial fetch on mount
|
||||
useEffect(() => {
|
||||
const newSearchId = activeSearchIdRef.current + 1;
|
||||
activeSearchIdRef.current = newSearchId;
|
||||
|
||||
const controller = new AbortController();
|
||||
currentAbortController.current = controller;
|
||||
|
||||
fetchInitialChats(searchQuery, newSearchId, controller.signal);
|
||||
|
||||
return () => {
|
||||
if (searchTimeoutRef.current) {
|
||||
clearTimeout(searchTimeoutRef.current);
|
||||
}
|
||||
controller.abort();
|
||||
};
|
||||
}, [fetchInitialChats, searchQuery]);
|
||||
|
||||
return {
|
||||
searchQuery,
|
||||
setSearchQuery,
|
||||
chatGroups,
|
||||
isLoading,
|
||||
isSearching: debouncedIsSearching,
|
||||
hasMore,
|
||||
fetchMoreChats,
|
||||
refreshChats: () => {
|
||||
const newSearchId = activeSearchIdRef.current + 1;
|
||||
activeSearchIdRef.current = newSearchId;
|
||||
|
||||
if (currentAbortController.current) {
|
||||
currentAbortController.current.abort();
|
||||
}
|
||||
|
||||
const controller = new AbortController();
|
||||
currentAbortController.current = controller;
|
||||
|
||||
return fetchInitialChats(searchQuery, newSearchId, controller.signal);
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
import { useEffect, useRef } from "react";
|
||||
|
||||
interface UseIntersectionObserverOptions {
|
||||
root?: Element | null;
|
||||
rootMargin?: string;
|
||||
threshold?: number;
|
||||
onIntersect: () => void;
|
||||
enabled?: boolean;
|
||||
}
|
||||
|
||||
export function useIntersectionObserver({
|
||||
root = null,
|
||||
rootMargin = "0px",
|
||||
threshold = 0.1,
|
||||
onIntersect,
|
||||
enabled = true,
|
||||
}: UseIntersectionObserverOptions) {
|
||||
const targetRef = useRef<HTMLDivElement | null>(null);
|
||||
const observerRef = useRef<IntersectionObserver | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!enabled) return;
|
||||
|
||||
const options = {
|
||||
root,
|
||||
rootMargin,
|
||||
threshold,
|
||||
};
|
||||
|
||||
const observer = new IntersectionObserver((entries) => {
|
||||
const [entry] = entries;
|
||||
if (entry.isIntersecting) {
|
||||
onIntersect();
|
||||
}
|
||||
}, options);
|
||||
|
||||
if (targetRef.current) {
|
||||
observer.observe(targetRef.current);
|
||||
}
|
||||
|
||||
observerRef.current = observer;
|
||||
|
||||
return () => {
|
||||
if (observerRef.current) {
|
||||
observerRef.current.disconnect();
|
||||
}
|
||||
};
|
||||
}, [root, rootMargin, threshold, onIntersect, enabled]);
|
||||
|
||||
return { targetRef };
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
import { ChatSessionSharedStatus } from "../interfaces";
|
||||
|
||||
export interface ChatSessionSummary {
|
||||
id: string;
|
||||
name: string | null;
|
||||
persona_id: number | null;
|
||||
time_created: string;
|
||||
shared_status: ChatSessionSharedStatus;
|
||||
folder_id: number | null;
|
||||
current_alternate_model: string | null;
|
||||
current_temperature_override: number | null;
|
||||
highlights?: string[];
|
||||
}
|
||||
|
||||
export interface ChatSessionGroup {
|
||||
title: string;
|
||||
chats: ChatSessionSummary[];
|
||||
}
|
||||
|
||||
export interface ChatSessionsResponse {
|
||||
sessions: ChatSessionSummary[];
|
||||
}
|
||||
|
||||
export interface ChatSearchResponse {
|
||||
groups: ChatSessionGroup[];
|
||||
has_more: boolean;
|
||||
next_page: number | null;
|
||||
}
|
||||
|
||||
export interface ChatSearchRequest {
|
||||
query?: string;
|
||||
page?: number;
|
||||
page_size?: number;
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
import { ChatSearchRequest, ChatSearchResponse } from "./interfaces";
|
||||
|
||||
const API_BASE_URL = "/api";
|
||||
|
||||
export interface ExtendedChatSearchRequest extends ChatSearchRequest {
|
||||
include_highlights?: boolean;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export async function fetchChatSessions(
|
||||
params: ExtendedChatSearchRequest = {}
|
||||
): Promise<ChatSearchResponse> {
|
||||
const queryParams = new URLSearchParams();
|
||||
|
||||
if (params.query) {
|
||||
queryParams.append("query", params.query);
|
||||
}
|
||||
|
||||
if (params.page) {
|
||||
queryParams.append("page", params.page.toString());
|
||||
}
|
||||
|
||||
if (params.page_size) {
|
||||
queryParams.append("page_size", params.page_size.toString());
|
||||
}
|
||||
|
||||
if (params.include_highlights !== undefined) {
|
||||
queryParams.append(
|
||||
"include_highlights",
|
||||
params.include_highlights.toString()
|
||||
);
|
||||
}
|
||||
|
||||
const queryString = queryParams.toString()
|
||||
? `?${queryParams.toString()}`
|
||||
: "";
|
||||
|
||||
const response = await fetch(`${API_BASE_URL}/chat/search${queryString}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch chat sessions: ${response.statusText}`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
export async function createNewChat(): Promise<{ chat_session_id: string }> {
|
||||
const response = await fetch(`${API_BASE_URL}/chat/sessions`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to create new chat: ${response.statusText}`);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
export async function deleteChat(chatId: string): Promise<void> {
|
||||
const response = await fetch(`${API_BASE_URL}/chat/sessions/${chatId}`, {
|
||||
method: "DELETE",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to delete chat: ${response.statusText}`);
|
||||
}
|
||||
}
|
||||
@@ -66,7 +66,6 @@ interface HistorySidebarProps {
|
||||
explicitlyUntoggle: () => void;
|
||||
showDeleteAllModal?: () => void;
|
||||
setShowAssistantsModal: (show: boolean) => void;
|
||||
toggleChatSessionSearchModal?: () => void;
|
||||
}
|
||||
|
||||
interface SortableAssistantProps {
|
||||
@@ -181,7 +180,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
toggleSidebar,
|
||||
removeToggle,
|
||||
showShareModal,
|
||||
toggleChatSessionSearchModal,
|
||||
showDeleteModal,
|
||||
showDeleteAllModal,
|
||||
},
|
||||
@@ -320,6 +318,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="h-full relative overflow-x-hidden overflow-y-auto">
|
||||
<div className="flex px-4 font-normal text-sm gap-x-2 leading-normal text-text-500/80 dark:text-[#D4D4D4] items-center font-normal leading-normal">
|
||||
Assistants
|
||||
@@ -396,7 +395,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
|
||||
</div>
|
||||
|
||||
<PagesTab
|
||||
toggleChatSessionSearchModal={toggleChatSessionSearchModal}
|
||||
showDeleteModal={showDeleteModal}
|
||||
showShareModal={showShareModal}
|
||||
closeSidebar={removeToggle}
|
||||
|
||||
@@ -17,13 +17,6 @@ import { useState, useCallback, useRef, useContext, useEffect } from "react";
|
||||
import { Caret } from "@/components/icons/icons";
|
||||
import { groupSessionsByDateRange } from "../lib";
|
||||
import React from "react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
TooltipContent,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { Search } from "lucide-react";
|
||||
import {
|
||||
DndContext,
|
||||
closestCenter,
|
||||
@@ -108,12 +101,10 @@ export function PagesTab({
|
||||
showShareModal,
|
||||
showDeleteModal,
|
||||
showDeleteAllModal,
|
||||
toggleChatSessionSearchModal,
|
||||
}: {
|
||||
existingChats?: ChatSession[];
|
||||
currentChatId?: string;
|
||||
folders?: Folder[];
|
||||
toggleChatSessionSearchModal?: () => void;
|
||||
closeSidebar?: () => void;
|
||||
showShareModal?: (chatSession: ChatSession) => void;
|
||||
showDeleteModal?: (chatSession: ChatSession) => void;
|
||||
@@ -327,28 +318,8 @@ export function PagesTab({
|
||||
<div className="flex flex-col gap-y-2 flex-grow">
|
||||
{popup}
|
||||
<div className="px-4 mt-2 group mr-2 bg-background-sidebar dark:bg-transparent z-20">
|
||||
<div className="flex group justify-between text-sm gap-x-2 text-text-300/80 items-center font-normal leading-normal">
|
||||
<div className="flex justify-between text-sm gap-x-2 text-text-300/80 items-center font-normal leading-normal">
|
||||
<p>Chats</p>
|
||||
|
||||
<TooltipProvider delayDuration={1000}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
className="my-auto mr-auto group-hover:opacity-100 opacity-0 transition duration-200 cursor-pointer gap-x-1 items-center text-black text-xs font-medium leading-normal mobile:hidden"
|
||||
onClick={() => {
|
||||
toggleChatSessionSearchModal?.();
|
||||
}}
|
||||
>
|
||||
<Search
|
||||
className="flex-none text-text-mobile-sidebar"
|
||||
size={12}
|
||||
/>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>Search Chats</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
|
||||
<button
|
||||
onClick={handleCreateFolder}
|
||||
className="flex group-hover:opacity-100 opacity-0 transition duration-200 cursor-pointer gap-x-1 items-center text-black text-xs font-medium leading-normal"
|
||||
|
||||
@@ -110,38 +110,37 @@ export default function LogoWithText({
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
<div className="flex ml-auto gap-x-4">
|
||||
{showArrow && toggleSidebar && (
|
||||
<TooltipProvider delayDuration={0}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
className="mr-2 my-auto"
|
||||
onClick={() => {
|
||||
toggleSidebar();
|
||||
if (toggled) {
|
||||
explicitlyUntoggle();
|
||||
}
|
||||
}}
|
||||
>
|
||||
{!toggled && !combinedSettings?.isMobile ? (
|
||||
<RightToLineIcon className="mobile:hidden text-sidebar-toggle" />
|
||||
) : (
|
||||
<LeftToLineIcon className="mobile:hidden text-sidebar-toggle" />
|
||||
)}
|
||||
<FiSidebar
|
||||
size={20}
|
||||
className="hidden mobile:block text-text-mobile-sidebar"
|
||||
/>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="!border-none">
|
||||
{toggled ? `Unpin sidebar` : "Pin sidebar"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{showArrow && toggleSidebar && (
|
||||
<TooltipProvider delayDuration={0}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
className="mr-2 my-auto ml-auto"
|
||||
onClick={() => {
|
||||
toggleSidebar();
|
||||
if (toggled) {
|
||||
explicitlyUntoggle();
|
||||
}
|
||||
}}
|
||||
>
|
||||
{!toggled && !combinedSettings?.isMobile ? (
|
||||
<RightToLineIcon className="mobile:hidden text-sidebar-toggle" />
|
||||
) : (
|
||||
<LeftToLineIcon className="mobile:hidden text-sidebar-toggle" />
|
||||
)}
|
||||
<FiSidebar
|
||||
size={20}
|
||||
className="hidden mobile:block text-text-mobile-sidebar"
|
||||
/>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="!border-none">
|
||||
{toggled ? `Unpin sidebar` : "Pin sidebar"}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -16,15 +16,12 @@ const DialogClose = DialogPrimitive.Close;
|
||||
|
||||
const DialogOverlay = React.forwardRef<
|
||||
React.ElementRef<typeof DialogPrimitive.Overlay>,
|
||||
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Overlay> & {
|
||||
backgroundColor?: string;
|
||||
}
|
||||
>(({ className, backgroundColor, ...props }, ref) => (
|
||||
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Overlay>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<DialogPrimitive.Overlay
|
||||
ref={ref}
|
||||
className={cn(
|
||||
backgroundColor || "bg-neutral-950/60",
|
||||
"fixed inset-0 z-50 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
|
||||
"fixed inset-0 z-50 bg-neutral-950/80 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
@@ -36,11 +33,10 @@ const DialogContent = React.forwardRef<
|
||||
React.ElementRef<typeof DialogPrimitive.Content>,
|
||||
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Content> & {
|
||||
hideCloseIcon?: boolean;
|
||||
backgroundColor?: string;
|
||||
}
|
||||
>(({ className, children, hideCloseIcon, backgroundColor, ...props }, ref) => (
|
||||
>(({ className, children, hideCloseIcon, ...props }, ref) => (
|
||||
<DialogPortal>
|
||||
<DialogOverlay backgroundColor={backgroundColor} />
|
||||
<DialogOverlay />
|
||||
<DialogPrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(
|
||||
|
||||
@@ -2,20 +2,13 @@ import * as React from "react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface InputProps extends React.ComponentProps<"input"> {
|
||||
removeFocusRing?: boolean;
|
||||
}
|
||||
|
||||
const Input = React.forwardRef<HTMLInputElement, InputProps>(
|
||||
({ className, type, removeFocusRing, ...props }, ref) => {
|
||||
const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<"input">>(
|
||||
({ className, type, ...props }, ref) => {
|
||||
return (
|
||||
<input
|
||||
type={type}
|
||||
className={cn(
|
||||
"flex h-10 w-full rounded-md border border-neutral-200 bg-white px-3 py-2 text-base ring-offset-white file:border-0 file:bg-transparent file:text-sm file:font-medium file:text-neutral-950 placeholder:text-neutral-500",
|
||||
removeFocusRing
|
||||
? ""
|
||||
: "focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-neutral-950 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm dark:border-neutral-800 dark:bg-neutral-950 dark:ring-offset-neutral-950 dark:file:text-neutral-50 dark:placeholder:text-neutral-400 dark:focus-visible:ring-neutral-300",
|
||||
"flex h-10 w-full rounded-md border border-neutral-200 bg-white px-3 py-2 text-base ring-offset-white file:border-0 file:bg-transparent file:text-sm file:font-medium file:text-neutral-950 placeholder:text-neutral-500 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-neutral-950 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm dark:border-neutral-800 dark:bg-neutral-950 dark:ring-offset-neutral-950 dark:file:text-neutral-50 dark:placeholder:text-neutral-400 dark:focus-visible:ring-neutral-300",
|
||||
className
|
||||
)}
|
||||
ref={ref}
|
||||
|
||||
@@ -18,7 +18,7 @@ async function verifyAdminPageNavigation(
|
||||
|
||||
try {
|
||||
await expect(page.locator("h1.text-3xl")).toHaveText(pageTitle, {
|
||||
timeout: 5000,
|
||||
timeout: 3000,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(
|
||||
|
||||
Reference in New Issue
Block a user