mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-25 19:55:47 +00:00
Compare commits
13 Commits
user-filte
...
testing
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25b38212e9 | ||
|
|
3096b0b2a7 | ||
|
|
342bb9f685 | ||
|
|
b25668c83a | ||
|
|
a72bd31f5d | ||
|
|
896e716d02 | ||
|
|
eec3ce8162 | ||
|
|
2761a837c6 | ||
|
|
da43abe644 | ||
|
|
af953ff8a3 | ||
|
|
6fc52c81ab | ||
|
|
1ad2128b2a | ||
|
|
880c42ad41 |
8
.github/pull_request_template.md
vendored
8
.github/pull_request_template.md
vendored
@@ -1,11 +1,15 @@
|
||||
## Description
|
||||
|
||||
[Provide a brief description of the changes in this PR]
|
||||
|
||||
|
||||
## How Has This Been Tested?
|
||||
|
||||
[Describe the tests you ran to verify your changes]
|
||||
|
||||
|
||||
## Backporting (check the box to trigger backport action)
|
||||
|
||||
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
|
||||
|
||||
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
|
||||
- [ ] I have included a link to a Linear ticket in my description.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
29
.github/workflows/pr-linear-check.yml
vendored
Normal file
29
.github/workflows/pr-linear-check.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: Ensure PR references Linear
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
jobs:
|
||||
linear-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check PR body for Linear link or override
|
||||
run: |
|
||||
PR_BODY="${{ github.event.pull_request.body }}"
|
||||
|
||||
# Looking for "https://linear.app" in the body
|
||||
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
|
||||
echo "Found a Linear link. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Looking for a checked override: "[x] Override Linear Check"
|
||||
if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
|
||||
echo "Override box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Otherwise, fail the run
|
||||
echo "No Linear link or override found in the PR description."
|
||||
exit 1
|
||||
1
.vscode/env_template.txt
vendored
1
.vscode/env_template.txt
vendored
@@ -29,6 +29,7 @@ REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
OPENAI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
@@ -17,9 +17,10 @@ Before starting, make sure the Docker Daemon is running.
|
||||
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
|
||||
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
|
||||
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
|
||||
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
4. CD into web, run "npm i" followed by npm run dev.
|
||||
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
|
||||
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
7. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
## Features
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Add has_been_indexed to DocumentByConnectorCredentialPair
|
||||
|
||||
Revision ID: c7bf5721733e
|
||||
Revises: fec3db967bf7
|
||||
Create Date: 2025-01-13 12:39:05.831693
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c7bf5721733e"
|
||||
down_revision = "027381bce97c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# assume all existing rows have been indexed, no better approach
|
||||
op.add_column(
|
||||
"document_by_connector_credential_pair",
|
||||
sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
|
||||
)
|
||||
op.alter_column(
|
||||
"document_by_connector_credential_pair",
|
||||
"has_been_indexed",
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Add index to optimize get_document_counts_for_cc_pairs query pattern
|
||||
op.create_index(
|
||||
"idx_document_cc_pair_counts",
|
||||
"document_by_connector_credential_pair",
|
||||
["connector_id", "credential_id", "has_been_indexed"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the index first before removing the column
|
||||
op.drop_index(
|
||||
"idx_document_cc_pair_counts",
|
||||
table_name="document_by_connector_credential_pair",
|
||||
)
|
||||
op.drop_column("document_by_connector_credential_pair", "has_been_indexed")
|
||||
@@ -1,6 +1,9 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
)
|
||||
@@ -8,7 +11,7 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate_usage_report",
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
},
|
||||
@@ -20,5 +23,9 @@ ee_tasks_to_schedule = [
|
||||
]
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return base_cloud_tasks_to_schedule
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_tasks_to_schedule + base_tasks_to_schedule
|
||||
|
||||
@@ -1,27 +1,135 @@
|
||||
import datetime
|
||||
from typing import Literal
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import asc
|
||||
from sqlalchemy import BinaryExpression
|
||||
from sqlalchemy import ColumnElement
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import distinct
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import case
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.sql import select
|
||||
from sqlalchemy.sql.expression import literal
|
||||
from sqlalchemy.sql.expression import UnaryExpression
|
||||
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
SortByOptions = Literal["time_sent"]
|
||||
|
||||
def _build_filter_conditions(
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
feedback_filter: QAFeedbackType | None,
|
||||
) -> list[ColumnElement]:
|
||||
"""
|
||||
Helper function to build all filter conditions for chat sessions.
|
||||
Filters by start and end time, feedback type, and any sessions without messages.
|
||||
start_time: Date from which to filter
|
||||
end_time: Date to which to filter
|
||||
feedback_filter: Feedback type to filter by
|
||||
Returns: List of filter conditions
|
||||
"""
|
||||
conditions = []
|
||||
|
||||
if start_time is not None:
|
||||
conditions.append(ChatSession.time_created >= start_time)
|
||||
if end_time is not None:
|
||||
conditions.append(ChatSession.time_created <= end_time)
|
||||
|
||||
if feedback_filter is not None:
|
||||
feedback_subq = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.join(ChatMessageFeedback)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.having(
|
||||
case(
|
||||
(
|
||||
case(
|
||||
{literal(feedback_filter == QAFeedbackType.LIKE): True},
|
||||
else_=False,
|
||||
),
|
||||
func.bool_and(ChatMessageFeedback.is_positive),
|
||||
),
|
||||
(
|
||||
case(
|
||||
{literal(feedback_filter == QAFeedbackType.DISLIKE): True},
|
||||
else_=False,
|
||||
),
|
||||
func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
|
||||
),
|
||||
else_=func.bool_or(ChatMessageFeedback.is_positive)
|
||||
& func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
|
||||
)
|
||||
)
|
||||
)
|
||||
conditions.append(ChatSession.id.in_(feedback_subq))
|
||||
|
||||
return conditions
|
||||
|
||||
|
||||
def get_total_filtered_chat_sessions_count(
|
||||
db_session: Session,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
feedback_filter: QAFeedbackType | None,
|
||||
) -> int:
|
||||
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
|
||||
stmt = (
|
||||
select(func.count(distinct(ChatSession.id)))
|
||||
.select_from(ChatSession)
|
||||
.filter(*conditions)
|
||||
)
|
||||
return db_session.scalar(stmt) or 0
|
||||
|
||||
|
||||
def get_page_of_chat_sessions(
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
db_session: Session,
|
||||
page_num: int,
|
||||
page_size: int,
|
||||
feedback_filter: QAFeedbackType | None = None,
|
||||
) -> Sequence[ChatSession]:
|
||||
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
|
||||
|
||||
subquery = (
|
||||
select(ChatSession.id, ChatSession.time_created)
|
||||
.filter(*conditions)
|
||||
.order_by(ChatSession.id, desc(ChatSession.time_created))
|
||||
.distinct(ChatSession.id)
|
||||
.limit(page_size)
|
||||
.offset(page_num * page_size)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(ChatSession)
|
||||
.join(subquery, ChatSession.id == subquery.c.id)
|
||||
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
|
||||
.options(
|
||||
joinedload(ChatSession.user),
|
||||
joinedload(ChatSession.persona),
|
||||
contains_eager(ChatSession.messages).joinedload(
|
||||
ChatMessage.chat_message_feedbacks
|
||||
),
|
||||
)
|
||||
.order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
|
||||
)
|
||||
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def fetch_chat_sessions_eagerly_by_time(
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
db_session: Session,
|
||||
limit: int | None = 500,
|
||||
initial_time: datetime.datetime | None = None,
|
||||
initial_time: datetime | None = None,
|
||||
) -> list[ChatSession]:
|
||||
time_order: UnaryExpression = desc(ChatSession.time_created)
|
||||
message_order: UnaryExpression = asc(ChatMessage.id)
|
||||
|
||||
@@ -120,9 +120,12 @@ def _get_permissions_from_slim_doc(
|
||||
elif permission_type == "anyone":
|
||||
public = True
|
||||
|
||||
drive_id = permission_info.get("drive_id")
|
||||
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_emails,
|
||||
external_user_group_ids=group_ids,
|
||||
is_public=public,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,16 +1,127 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from onyx.connectors.google_utils.resources import AdminService
|
||||
from onyx.connectors.google_utils.resources import get_admin_service
|
||||
from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_drive_members(
|
||||
google_drive_connector: GoogleDriveConnector,
|
||||
) -> dict[str, tuple[set[str], set[str]]]:
|
||||
"""
|
||||
This builds a map of drive ids to their members (group and user emails).
|
||||
E.g. {
|
||||
"drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}),
|
||||
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
|
||||
}
|
||||
"""
|
||||
drive_ids = google_drive_connector.get_all_drive_ids()
|
||||
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
|
||||
drive_service = get_drive_service(
|
||||
google_drive_connector.creds,
|
||||
google_drive_connector.primary_admin_email,
|
||||
)
|
||||
|
||||
for drive_id in drive_ids:
|
||||
group_emails: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
for permission in execute_paginated_retrieval(
|
||||
drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=drive_id,
|
||||
fields="permissions(emailAddress, type)",
|
||||
supportsAllDrives=True,
|
||||
):
|
||||
if permission["type"] == "group":
|
||||
group_emails.add(permission["emailAddress"])
|
||||
elif permission["type"] == "user":
|
||||
user_emails.add(permission["emailAddress"])
|
||||
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
|
||||
return drive_id_to_members_map
|
||||
|
||||
|
||||
def _get_all_groups(
|
||||
admin_service: AdminService,
|
||||
google_domain: str,
|
||||
) -> set[str]:
|
||||
"""
|
||||
This gets all the group emails.
|
||||
"""
|
||||
group_emails: set[str] = set()
|
||||
for group in execute_paginated_retrieval(
|
||||
admin_service.groups().list,
|
||||
list_key="groups",
|
||||
domain=google_domain,
|
||||
fields="groups(email)",
|
||||
):
|
||||
group_emails.add(group["email"])
|
||||
return group_emails
|
||||
|
||||
|
||||
def _map_group_email_to_member_emails(
|
||||
admin_service: AdminService,
|
||||
group_emails: set[str],
|
||||
) -> dict[str, set[str]]:
|
||||
"""
|
||||
This maps group emails to their member emails.
|
||||
"""
|
||||
group_to_member_map: dict[str, set[str]] = {}
|
||||
for group_email in group_emails:
|
||||
group_member_emails: set[str] = set()
|
||||
for member in execute_paginated_retrieval(
|
||||
admin_service.members().list,
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email)",
|
||||
):
|
||||
group_member_emails.add(member["email"])
|
||||
|
||||
group_to_member_map[group_email] = group_member_emails
|
||||
return group_to_member_map
|
||||
|
||||
|
||||
def _build_onyx_groups(
|
||||
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
|
||||
group_email_to_member_emails_map: dict[str, set[str]],
|
||||
) -> list[ExternalUserGroup]:
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
|
||||
# Convert all drive member definitions to onyx groups
|
||||
# This is because having drive level access means you have
|
||||
# irrevocable access to all the files in the drive.
|
||||
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
|
||||
all_member_emails: set[str] = user_emails
|
||||
for group_email in group_emails:
|
||||
all_member_emails.update(group_email_to_member_emails_map[group_email])
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=drive_id,
|
||||
user_emails=list(all_member_emails),
|
||||
)
|
||||
)
|
||||
|
||||
# Convert all group member definitions to onyx groups
|
||||
for group_email, member_emails in group_email_to_member_emails_map.items():
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_email,
|
||||
user_emails=list(member_emails),
|
||||
)
|
||||
)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
|
||||
def gdrive_group_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
# Initialize connector and build credential/service objects
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
@@ -19,34 +130,23 @@ def gdrive_group_sync(
|
||||
google_drive_connector.creds, google_drive_connector.primary_admin_email
|
||||
)
|
||||
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
for group in execute_paginated_retrieval(
|
||||
admin_service.groups().list,
|
||||
list_key="groups",
|
||||
domain=google_drive_connector.google_domain,
|
||||
fields="groups(email)",
|
||||
):
|
||||
# The id is the group email
|
||||
group_email = group["email"]
|
||||
# Get all drive members
|
||||
drive_id_to_members_map = _get_drive_members(google_drive_connector)
|
||||
|
||||
# Gather group member emails
|
||||
group_member_emails: list[str] = []
|
||||
for member in execute_paginated_retrieval(
|
||||
admin_service.members().list,
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email)",
|
||||
):
|
||||
group_member_emails.append(member["email"])
|
||||
# Get all group emails
|
||||
all_group_emails = _get_all_groups(
|
||||
admin_service, google_drive_connector.google_domain
|
||||
)
|
||||
|
||||
if not group_member_emails:
|
||||
continue
|
||||
# Map group emails to their members
|
||||
group_email_to_member_emails_map = _map_group_email_to_member_emails(
|
||||
admin_service, all_group_emails
|
||||
)
|
||||
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_email,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
)
|
||||
# Convert the maps to onyx groups
|
||||
onyx_groups = _build_onyx_groups(
|
||||
drive_id_to_members_map=drive_id_to_members_map,
|
||||
group_email_to_member_emails_map=group_email_to_member_emails_map,
|
||||
)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
|
||||
from ee.onyx.db.query_history import get_page_of_chat_sessions
|
||||
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from ee.onyx.server.query_history.models import MessageSnapshot
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
@@ -23,257 +27,15 @@ from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import User
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.query_and_chat.models import ChatSessionDetails
|
||||
from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AbridgedSearchDoc(BaseModel):
|
||||
"""A subset of the info present in `SearchDoc`"""
|
||||
|
||||
document_id: str
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
|
||||
|
||||
class MessageSnapshot(BaseModel):
|
||||
message: str
|
||||
message_type: MessageType
|
||||
documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
time_created: datetime
|
||||
|
||||
@classmethod
|
||||
def build(cls, message: ChatMessage) -> "MessageSnapshot":
|
||||
latest_messages_feedback_obj = (
|
||||
message.chat_message_feedbacks[-1]
|
||||
if len(message.chat_message_feedbacks) > 0
|
||||
else None
|
||||
)
|
||||
feedback_type = (
|
||||
(
|
||||
QAFeedbackType.LIKE
|
||||
if latest_messages_feedback_obj.is_positive
|
||||
else QAFeedbackType.DISLIKE
|
||||
)
|
||||
if latest_messages_feedback_obj
|
||||
else None
|
||||
)
|
||||
feedback_text = (
|
||||
latest_messages_feedback_obj.feedback_text
|
||||
if latest_messages_feedback_obj
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
message=message.message,
|
||||
message_type=message.message_type,
|
||||
documents=[
|
||||
AbridgedSearchDoc(
|
||||
document_id=document.document_id,
|
||||
semantic_identifier=document.semantic_id,
|
||||
link=document.link,
|
||||
)
|
||||
for document in message.search_docs
|
||||
],
|
||||
feedback_type=feedback_type,
|
||||
feedback_text=feedback_text,
|
||||
time_created=message.time_sent,
|
||||
)
|
||||
|
||||
|
||||
class ChatSessionMinimal(BaseModel):
|
||||
id: UUID
|
||||
user_email: str
|
||||
name: str | None
|
||||
first_user_message: str
|
||||
first_ai_message: str
|
||||
assistant_id: int | None
|
||||
assistant_name: str | None
|
||||
time_created: datetime
|
||||
feedback_type: QAFeedbackType | Literal["mixed"] | None
|
||||
flow_type: SessionType
|
||||
conversation_length: int
|
||||
|
||||
|
||||
class ChatSessionSnapshot(BaseModel):
|
||||
id: UUID
|
||||
user_email: str
|
||||
name: str | None
|
||||
messages: list[MessageSnapshot]
|
||||
assistant_id: int | None
|
||||
assistant_name: str | None
|
||||
time_created: datetime
|
||||
flow_type: SessionType
|
||||
|
||||
|
||||
class QuestionAnswerPairSnapshot(BaseModel):
|
||||
chat_session_id: UUID
|
||||
# 1-indexed message number in the chat_session
|
||||
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
|
||||
message_pair_num: int
|
||||
user_message: str
|
||||
ai_response: str
|
||||
retrieved_documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
persona_name: str | None
|
||||
user_email: str
|
||||
time_created: datetime
|
||||
flow_type: SessionType
|
||||
|
||||
@classmethod
|
||||
def from_chat_session_snapshot(
|
||||
cls,
|
||||
chat_session_snapshot: ChatSessionSnapshot,
|
||||
) -> list["QuestionAnswerPairSnapshot"]:
|
||||
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
|
||||
for ind in range(1, len(chat_session_snapshot.messages), 2):
|
||||
message_pairs.append(
|
||||
(
|
||||
chat_session_snapshot.messages[ind - 1],
|
||||
chat_session_snapshot.messages[ind],
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
cls(
|
||||
chat_session_id=chat_session_snapshot.id,
|
||||
message_pair_num=ind + 1,
|
||||
user_message=user_message.message,
|
||||
ai_response=ai_message.message,
|
||||
retrieved_documents=ai_message.documents,
|
||||
feedback_type=ai_message.feedback_type,
|
||||
feedback_text=ai_message.feedback_text,
|
||||
persona_name=chat_session_snapshot.assistant_name,
|
||||
user_email=get_display_email(chat_session_snapshot.user_email),
|
||||
time_created=user_message.time_created,
|
||||
flow_type=chat_session_snapshot.flow_type,
|
||||
)
|
||||
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
||||
]
|
||||
|
||||
def to_json(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"chat_session_id": str(self.chat_session_id),
|
||||
"message_pair_num": str(self.message_pair_num),
|
||||
"user_message": self.user_message,
|
||||
"ai_response": self.ai_response,
|
||||
"retrieved_documents": "|".join(
|
||||
[
|
||||
doc.link or doc.semantic_identifier
|
||||
for doc in self.retrieved_documents
|
||||
]
|
||||
),
|
||||
"feedback_type": self.feedback_type.value if self.feedback_type else "",
|
||||
"feedback_text": self.feedback_text or "",
|
||||
"persona_name": self.persona_name,
|
||||
"user_email": self.user_email,
|
||||
"time_created": str(self.time_created),
|
||||
"flow_type": self.flow_type,
|
||||
}
|
||||
|
||||
|
||||
def determine_flow_type(chat_session: ChatSession) -> SessionType:
|
||||
return SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history_minimal(
|
||||
db_session: Session,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
feedback_filter: QAFeedbackType | None = None,
|
||||
limit: int | None = 500,
|
||||
) -> list[ChatSessionMinimal]:
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=start, end=end, db_session=db_session, limit=limit
|
||||
)
|
||||
|
||||
minimal_sessions = []
|
||||
for chat_session in chat_sessions:
|
||||
if not chat_session.messages:
|
||||
continue
|
||||
|
||||
first_user_message = next(
|
||||
(
|
||||
message.message
|
||||
for message in chat_session.messages
|
||||
if message.message_type == MessageType.USER
|
||||
),
|
||||
"",
|
||||
)
|
||||
first_ai_message = next(
|
||||
(
|
||||
message.message
|
||||
for message in chat_session.messages
|
||||
if message.message_type == MessageType.ASSISTANT
|
||||
),
|
||||
"",
|
||||
)
|
||||
|
||||
has_positive_feedback = any(
|
||||
feedback.is_positive
|
||||
for message in chat_session.messages
|
||||
for feedback in message.chat_message_feedbacks
|
||||
)
|
||||
|
||||
has_negative_feedback = any(
|
||||
not feedback.is_positive
|
||||
for message in chat_session.messages
|
||||
for feedback in message.chat_message_feedbacks
|
||||
)
|
||||
|
||||
feedback_type: QAFeedbackType | Literal["mixed"] | None = (
|
||||
"mixed"
|
||||
if has_positive_feedback and has_negative_feedback
|
||||
else QAFeedbackType.LIKE
|
||||
if has_positive_feedback
|
||||
else QAFeedbackType.DISLIKE
|
||||
if has_negative_feedback
|
||||
else None
|
||||
)
|
||||
|
||||
if feedback_filter:
|
||||
if feedback_filter == QAFeedbackType.LIKE and not has_positive_feedback:
|
||||
continue
|
||||
if feedback_filter == QAFeedbackType.DISLIKE and not has_negative_feedback:
|
||||
continue
|
||||
|
||||
flow_type = determine_flow_type(chat_session)
|
||||
|
||||
minimal_sessions.append(
|
||||
ChatSessionMinimal(
|
||||
id=chat_session.id,
|
||||
user_email=get_display_email(
|
||||
chat_session.user.email if chat_session.user else None
|
||||
),
|
||||
name=chat_session.description,
|
||||
first_user_message=first_user_message,
|
||||
first_ai_message=first_ai_message,
|
||||
assistant_id=chat_session.persona_id,
|
||||
assistant_name=(
|
||||
chat_session.persona.name if chat_session.persona else None
|
||||
),
|
||||
time_created=chat_session.time_created,
|
||||
feedback_type=feedback_type,
|
||||
flow_type=flow_type,
|
||||
conversation_length=len(
|
||||
[
|
||||
m
|
||||
for m in chat_session.messages
|
||||
if m.message_type != MessageType.SYSTEM
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return minimal_sessions
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history(
|
||||
db_session: Session,
|
||||
start: datetime,
|
||||
@@ -319,7 +81,7 @@ def snapshot_from_chat_session(
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
flow_type = determine_flow_type(chat_session)
|
||||
flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
|
||||
|
||||
return ChatSessionSnapshot(
|
||||
id=chat_session.id,
|
||||
@@ -371,22 +133,38 @@ def get_user_chat_sessions(
|
||||
|
||||
@router.get("/admin/chat-session-history")
|
||||
def get_chat_session_history(
|
||||
page_num: int = Query(0, ge=0),
|
||||
page_size: int = Query(10, ge=1),
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[ChatSessionMinimal]:
|
||||
return fetch_and_process_chat_session_history_minimal(
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
page_of_chat_sessions = get_page_of_chat_sessions(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
db_session=db_session,
|
||||
start=start
|
||||
or (
|
||||
datetime.now(tz=timezone.utc) - timedelta(days=30)
|
||||
), # default is 30d lookback
|
||||
end=end or datetime.now(tz=timezone.utc),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
total_filtered_chat_sessions_count = get_total_filtered_chat_sessions_count(
|
||||
db_session=db_session,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
ChatSessionMinimal.from_chat_session(chat_session)
|
||||
for chat_session in page_of_chat_sessions
|
||||
],
|
||||
total_items=total_filtered_chat_sessions_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/chat-session-history/{chat_session_id}")
|
||||
def get_chat_session_admin(
|
||||
|
||||
218
backend/ee/onyx/server/query_history/models.py
Normal file
218
backend/ee/onyx/server/query_history/models.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
class AbridgedSearchDoc(BaseModel):
|
||||
"""A subset of the info present in `SearchDoc`"""
|
||||
|
||||
document_id: str
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
|
||||
|
||||
class MessageSnapshot(BaseModel):
|
||||
id: int
|
||||
message: str
|
||||
message_type: MessageType
|
||||
documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
time_created: datetime
|
||||
|
||||
@classmethod
|
||||
def build(cls, message: ChatMessage) -> "MessageSnapshot":
|
||||
latest_messages_feedback_obj = (
|
||||
message.chat_message_feedbacks[-1]
|
||||
if len(message.chat_message_feedbacks) > 0
|
||||
else None
|
||||
)
|
||||
feedback_type = (
|
||||
(
|
||||
QAFeedbackType.LIKE
|
||||
if latest_messages_feedback_obj.is_positive
|
||||
else QAFeedbackType.DISLIKE
|
||||
)
|
||||
if latest_messages_feedback_obj
|
||||
else None
|
||||
)
|
||||
feedback_text = (
|
||||
latest_messages_feedback_obj.feedback_text
|
||||
if latest_messages_feedback_obj
|
||||
else None
|
||||
)
|
||||
return cls(
|
||||
id=message.id,
|
||||
message=message.message,
|
||||
message_type=message.message_type,
|
||||
documents=[
|
||||
AbridgedSearchDoc(
|
||||
document_id=document.document_id,
|
||||
semantic_identifier=document.semantic_id,
|
||||
link=document.link,
|
||||
)
|
||||
for document in message.search_docs
|
||||
],
|
||||
feedback_type=feedback_type,
|
||||
feedback_text=feedback_text,
|
||||
time_created=message.time_sent,
|
||||
)
|
||||
|
||||
|
||||
class ChatSessionMinimal(BaseModel):
|
||||
id: UUID
|
||||
user_email: str
|
||||
name: str | None
|
||||
first_user_message: str
|
||||
first_ai_message: str
|
||||
assistant_id: int | None
|
||||
assistant_name: str | None
|
||||
time_created: datetime
|
||||
feedback_type: QAFeedbackType | None
|
||||
flow_type: SessionType
|
||||
conversation_length: int
|
||||
|
||||
@classmethod
|
||||
def from_chat_session(cls, chat_session: ChatSession) -> "ChatSessionMinimal":
|
||||
first_user_message = next(
|
||||
(
|
||||
message.message
|
||||
for message in chat_session.messages
|
||||
if message.message_type == MessageType.USER
|
||||
),
|
||||
"",
|
||||
)
|
||||
first_ai_message = next(
|
||||
(
|
||||
message.message
|
||||
for message in chat_session.messages
|
||||
if message.message_type == MessageType.ASSISTANT
|
||||
),
|
||||
"",
|
||||
)
|
||||
|
||||
list_of_message_feedbacks = [
|
||||
feedback.is_positive
|
||||
for message in chat_session.messages
|
||||
for feedback in message.chat_message_feedbacks
|
||||
]
|
||||
session_feedback_type = None
|
||||
if list_of_message_feedbacks:
|
||||
if all(list_of_message_feedbacks):
|
||||
session_feedback_type = QAFeedbackType.LIKE
|
||||
elif not any(list_of_message_feedbacks):
|
||||
session_feedback_type = QAFeedbackType.DISLIKE
|
||||
else:
|
||||
session_feedback_type = QAFeedbackType.MIXED
|
||||
|
||||
return cls(
|
||||
id=chat_session.id,
|
||||
user_email=get_display_email(
|
||||
chat_session.user.email if chat_session.user else None
|
||||
),
|
||||
name=chat_session.description,
|
||||
first_user_message=first_user_message,
|
||||
first_ai_message=first_ai_message,
|
||||
assistant_id=chat_session.persona_id,
|
||||
assistant_name=(
|
||||
chat_session.persona.name if chat_session.persona else None
|
||||
),
|
||||
time_created=chat_session.time_created,
|
||||
feedback_type=session_feedback_type,
|
||||
flow_type=SessionType.SLACK
|
||||
if chat_session.onyxbot_flow
|
||||
else SessionType.CHAT,
|
||||
conversation_length=len(
|
||||
[
|
||||
message
|
||||
for message in chat_session.messages
|
||||
if message.message_type != MessageType.SYSTEM
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ChatSessionSnapshot(BaseModel):
|
||||
id: UUID
|
||||
user_email: str
|
||||
name: str | None
|
||||
messages: list[MessageSnapshot]
|
||||
assistant_id: int | None
|
||||
assistant_name: str | None
|
||||
time_created: datetime
|
||||
flow_type: SessionType
|
||||
|
||||
|
||||
class QuestionAnswerPairSnapshot(BaseModel):
|
||||
chat_session_id: UUID
|
||||
# 1-indexed message number in the chat_session
|
||||
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
|
||||
message_pair_num: int
|
||||
user_message: str
|
||||
ai_response: str
|
||||
retrieved_documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
persona_name: str | None
|
||||
user_email: str
|
||||
time_created: datetime
|
||||
flow_type: SessionType
|
||||
|
||||
@classmethod
|
||||
def from_chat_session_snapshot(
|
||||
cls,
|
||||
chat_session_snapshot: ChatSessionSnapshot,
|
||||
) -> list["QuestionAnswerPairSnapshot"]:
|
||||
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
|
||||
for ind in range(1, len(chat_session_snapshot.messages), 2):
|
||||
message_pairs.append(
|
||||
(
|
||||
chat_session_snapshot.messages[ind - 1],
|
||||
chat_session_snapshot.messages[ind],
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
cls(
|
||||
chat_session_id=chat_session_snapshot.id,
|
||||
message_pair_num=ind + 1,
|
||||
user_message=user_message.message,
|
||||
ai_response=ai_message.message,
|
||||
retrieved_documents=ai_message.documents,
|
||||
feedback_type=ai_message.feedback_type,
|
||||
feedback_text=ai_message.feedback_text,
|
||||
persona_name=chat_session_snapshot.assistant_name,
|
||||
user_email=get_display_email(chat_session_snapshot.user_email),
|
||||
time_created=user_message.time_created,
|
||||
flow_type=chat_session_snapshot.flow_type,
|
||||
)
|
||||
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
||||
]
|
||||
|
||||
def to_json(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"chat_session_id": str(self.chat_session_id),
|
||||
"message_pair_num": str(self.message_pair_num),
|
||||
"user_message": self.user_message,
|
||||
"ai_response": self.ai_response,
|
||||
"retrieved_documents": "|".join(
|
||||
[
|
||||
doc.link or doc.semantic_identifier
|
||||
for doc in self.retrieved_documents
|
||||
]
|
||||
),
|
||||
"feedback_type": self.feedback_type.value if self.feedback_type else "",
|
||||
"feedback_text": self.feedback_text or "",
|
||||
"persona_name": self.persona_name,
|
||||
"user_email": self.user_email,
|
||||
"time_created": str(self.time_created),
|
||||
"flow_type": self.flow_type,
|
||||
}
|
||||
@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.server.settings.store import store_settings as store_base_settings
|
||||
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
|
||||
llms: list[LLMProviderUpsertRequest] | None = None
|
||||
admin_user_emails: list[str] | None = None
|
||||
seeded_logo_path: str | None = None
|
||||
personas: list[CreatePersonaRequest] | None = None
|
||||
personas: list[PersonaUpsertRequest] | None = None
|
||||
settings: Settings | None = None
|
||||
enterprise_settings: EnterpriseSettings | None = None
|
||||
|
||||
@@ -128,7 +128,7 @@ def _seed_llms(
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
if personas:
|
||||
logger.notice("Seeding Personas")
|
||||
for persona in personas:
|
||||
|
||||
@@ -20,6 +20,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
@@ -100,6 +101,10 @@ def on_task_postrun(
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
if task.name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# this is a cloud / all tenant task ... no postrun is needed
|
||||
return
|
||||
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -7,12 +8,14 @@ from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
@@ -28,7 +31,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
# Let the parent class handle store initialization
|
||||
self.setup_schedule()
|
||||
self._update_tenant_tasks()
|
||||
self._try_updating_schedule()
|
||||
logger.info(f"Set reload interval to {self._reload_interval}")
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
@@ -44,105 +47,154 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating task update")
|
||||
self._update_tenant_tasks()
|
||||
try:
|
||||
self._try_updating_schedule()
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
|
||||
self._last_reload = now
|
||||
logger.info("Task update completed, reset reload timer")
|
||||
return retval
|
||||
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Starting task update process")
|
||||
try:
|
||||
logger.info("Fetching all IDs")
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
def _generate_schedule(
|
||||
self, tenant_ids: list[str] | list[None]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Given a list of tenant id's, generates a new beat schedule for celery."""
|
||||
logger.info("Fetching tasks to schedule")
|
||||
|
||||
logger.info("Fetching tasks to schedule")
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
new_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if MULTI_TENANT:
|
||||
# cloud tasks only need the single task beat across all tenants
|
||||
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule",
|
||||
"get_cloud_tasks_to_schedule",
|
||||
)
|
||||
|
||||
new_beat_schedule: dict[str, dict[str, Any]] = {}
|
||||
cloud_tasks_to_schedule: list[
|
||||
dict[str, Any]
|
||||
] = get_cloud_tasks_to_schedule()
|
||||
for task in cloud_tasks_to_schedule:
|
||||
task_name = task["name"]
|
||||
cloud_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
cloud_task["options"] = options
|
||||
new_schedule[task_name] = cloud_task
|
||||
|
||||
current_schedule = self.schedule.items()
|
||||
# regular task beats are multiplied across all tenants
|
||||
get_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
|
||||
existing_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
|
||||
tasks_to_schedule: list[dict[str, Any]] = get_tasks_to_schedule()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
IGNORED_SYNCING_TENANT_LIST
|
||||
and tenant_id in IGNORED_SYNCING_TENANT_LIST
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new item: {tenant_id}")
|
||||
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
logger.debug(f"Creating task configuration for {task_name}")
|
||||
new_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
new_task["options"] = options
|
||||
new_beat_schedule[task_name] = new_task
|
||||
|
||||
if self._should_update_schedule(current_schedule, new_beat_schedule):
|
||||
for tenant_id in tenant_ids:
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
logger.info(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_beat_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
|
||||
# Create schedule entries
|
||||
entries = {}
|
||||
for name, entry in new_beat_schedule.items():
|
||||
entries[name] = self.Entry(
|
||||
name=name,
|
||||
app=self.app,
|
||||
task=entry["task"],
|
||||
schedule=entry["schedule"],
|
||||
options=entry.get("options", {}),
|
||||
kwargs=entry.get("kwargs", {}),
|
||||
for task in tasks_to_schedule:
|
||||
task_name = task["name"]
|
||||
tenant_task_name = f"{task['name']}-{tenant_id}"
|
||||
|
||||
logger.debug(f"Creating task configuration for {tenant_task_name}")
|
||||
tenant_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(
|
||||
f"Adding options to task {tenant_task_name}: {options}"
|
||||
)
|
||||
tenant_task["options"] = options
|
||||
new_schedule[tenant_task_name] = tenant_task
|
||||
|
||||
# Update the schedule using the scheduler's methods
|
||||
self.schedule.clear()
|
||||
self.schedule.update(entries)
|
||||
return new_schedule
|
||||
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
def _try_updating_schedule(self) -> None:
|
||||
"""Only updates the actual beat schedule on the celery app when it changes"""
|
||||
|
||||
logger.info("Schedule update completed successfully")
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
logger.info("_try_updating_schedule starting")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
) -> bool:
|
||||
"""Compare schedules to determine if an update is needed."""
|
||||
logger.debug("Comparing current and new schedules")
|
||||
current_tasks = set(name for name, _ in current_schedule)
|
||||
new_tasks = set(new_schedule.keys())
|
||||
needs_update = current_tasks != new_tasks
|
||||
logger.debug(f"Schedule update needed: {needs_update}")
|
||||
return needs_update
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
|
||||
# get current schedule and extract current tenants
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
current_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
task_name = cast(str, task_name)
|
||||
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
continue
|
||||
|
||||
if "_" in task_name:
|
||||
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
current_tenants.add(task_name.split("_")[-1])
|
||||
logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id not in current_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids)
|
||||
|
||||
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
|
||||
logger.info(
|
||||
"_try_updating_schedule: Current schedule is up to date, no changes needed"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_schedule),
|
||||
"current_tasks": len(current_schedule),
|
||||
},
|
||||
)
|
||||
|
||||
# Create schedule entries
|
||||
entries = {}
|
||||
for name, entry in new_schedule.items():
|
||||
entries[name] = self.Entry(
|
||||
name=name,
|
||||
app=self.app,
|
||||
task=entry["task"],
|
||||
schedule=entry["schedule"],
|
||||
options=entry.get("options", {}),
|
||||
kwargs=entry.get("kwargs", {}),
|
||||
)
|
||||
|
||||
# Update the schedule using the scheduler's methods
|
||||
self.schedule.clear()
|
||||
self.schedule.update(entries)
|
||||
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
logger.info("_try_updating_schedule: Schedule updated successfully")
|
||||
|
||||
@staticmethod
|
||||
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
|
||||
"""Compare schedules to determine if an update is needed.
|
||||
True if equivalent, False if not."""
|
||||
current_tasks = set(name for name, _ in schedule1)
|
||||
new_tasks = set(schedule2.keys())
|
||||
if current_tasks != new_tasks:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
|
||||
@@ -17,7 +17,7 @@ from redis.lock import Lock as RedisLock
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.tasks.indexing.tasks import (
|
||||
from onyx.background.celery.tasks.indexing.utils import (
|
||||
get_unfenced_index_attempt_ids,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
|
||||
@@ -2,25 +2,43 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
# we might be able to reduce this greatly if we can run a unified
|
||||
# loop across all tenants rather than tasks per tenant
|
||||
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# tasks that only run in the cloud
|
||||
# the name attribute must start with ONYX_CELERY_CLOUD_PREFIX = "cloud" to be filtered
|
||||
# by the DynamicTenantScheduler
|
||||
cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
@@ -29,16 +47,7 @@ tasks_to_schedule = [
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
@@ -47,7 +56,7 @@ tasks_to_schedule = [
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
@@ -65,7 +74,7 @@ tasks_to_schedule = [
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
@@ -84,7 +93,7 @@ tasks_to_schedule = [
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
@@ -93,12 +102,25 @@ tasks_to_schedule = [
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if not MULTI_TENANT:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
tasks_to_schedule.append(
|
||||
@@ -114,5 +136,9 @@ if LLM_MODEL_UPDATE_API_URL:
|
||||
)
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return cloud_tasks_to_schedule
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -44,7 +44,7 @@ def check_for_connector_deletion_task(
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
|
||||
@@ -22,9 +22,9 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
@@ -99,7 +99,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
|
||||
@@ -22,7 +22,7 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -99,7 +99,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
519
backend/onyx/background/celery/tasks/indexing/utils.py
Normal file
519
backend/onyx/background/celery/tasks/indexing/utils.py
Normal file
@@ -0,0 +1,519 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import create_index_attempt
|
||||
from onyx.db.index_attempt import delete_index_attempt
|
||||
from onyx.db.index_attempt import get_all_index_attempts_by_status
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
|
||||
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
|
||||
want to clean them up.
|
||||
|
||||
Unfenced = attempt not in terminal state and fence does not exist.
|
||||
"""
|
||||
unfenced_attempts: list[int] = []
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
# inner = index_attempt in non terminal state
|
||||
# outer = r.fence_key down
|
||||
|
||||
# check the db for index attempts in a non terminal state
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for attempt in attempts:
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
attempt.connector_credential_pair_id, attempt.search_settings_id
|
||||
)
|
||||
|
||||
# if the fence is down / doesn't exist, possible error but not confirmed
|
||||
if r.exists(fence_key):
|
||||
continue
|
||||
|
||||
# Between the time the attempts are first looked up and the time we see the fence down,
|
||||
# the attempt may have completed and taken down the fence normally.
|
||||
|
||||
# We need to double check that the index attempt is still in a non terminal state
|
||||
# and matches the original state, which confirms we are really in a bad state.
|
||||
attempt_2 = get_index_attempt(db_session, attempt.id)
|
||||
if not attempt_2:
|
||||
continue
|
||||
|
||||
if attempt.status != attempt_2.status:
|
||||
continue
|
||||
|
||||
unfenced_attempts.append(attempt.id)
|
||||
|
||||
return unfenced_attempts
|
||||
|
||||
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
|
||||
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
|
||||
# so leave this code in until we're ready to test it.
|
||||
|
||||
# if self.parent_pid:
|
||||
# # check if the parent pid is alive so we aren't running as a zombie
|
||||
# now = time.monotonic()
|
||||
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
|
||||
# try:
|
||||
# # this is unintuitive, but it checks if the parent pid is still running
|
||||
# os.kill(self.parent_pid, 0)
|
||||
# except Exception:
|
||||
# logger.exception("IndexingCallback - parent pid check exceptioned")
|
||||
# raise
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
self.redis_lock.reacquire()
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_tag = tag
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
|
||||
|
||||
def validate_indexing_fence(
    tenant_id: str | None,
    key_bytes: bytes,
    reserved_tasks: set[str],
    r_celery: Redis,
    db_session: Session,
) -> None:
    """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
    This can happen if the indexing worker hard crashes or is terminated.
    Being in this bad state means the fence will never clear without help, so this function
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1.2. When the task is seen in the redis queue
    1.3. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
    2.1. The fence is created
    2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
    whether a task is in the queue or currently executing.
    1. An unknown task id is always returned as state PENDING.
    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
    and the time it actually starts on the worker.

    Args:
        tenant_id: tenant the fence belongs to (None in single-tenant mode).
        key_bytes: the raw Redis fence key found by the scan.
        reserved_tasks: celery task ids currently prefetched by indexing workers.
        r_celery: Redis client for the celery broker db.
        db_session: DB session used to mark orphaned index attempts as failed.
    """
    # if the fence doesn't exist, there's nothing to do
    fence_key = key_bytes.decode("utf-8")
    composite_id = RedisConnector.get_id_from_fence_key(fence_key)
    if composite_id is None:
        task_logger.warning(
            f"validate_indexing_fence - could not parse composite_id from {fence_key}"
        )
        return

    # parse out metadata and initialize the helper class with it
    # composite_id is expected to be "<cc_pair_id>/<search_settings_id>"
    parts = composite_id.split("/")
    if len(parts) != 2:
        return

    cc_pair_id = int(parts[0])
    search_settings_id = int(parts[1])

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings_id)

    # check to see if the fence/payload exists
    if not redis_connector_index.fenced:
        return

    payload = redis_connector_index.payload
    if not payload:
        return

    # OK, there's actually something for us to validate

    if payload.celery_task_id is None:
        # the fence is just barely set up.
        if redis_connector_index.active():
            return

        # it would be odd to get here as there isn't that much that can go wrong during
        # initial fence setup, but it's still worth making sure we can recover
        # NOTE: use task_logger consistently (this block previously mixed in `logger`)
        task_logger.info(
            f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
        )
        redis_connector_index.reset()
        return

    found = celery_find_task(
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )
    if found:
        # the celery task exists in the redis queue
        redis_connector_index.set_active()
        return

    if payload.celery_task_id in reserved_tasks:
        # the celery task was prefetched and is reserved within the indexing worker
        redis_connector_index.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    if redis_connector_index.active():
        return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    task_logger.warning(
        f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
        f"index_attempt={payload.index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id} "
        f"fence={fence_key}"
    )
    if payload.index_attempt_id:
        try:
            mark_attempt_failed(
                payload.index_attempt_id,
                db_session,
                "validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
                f"index_attempt={payload.index_attempt_id}",
            )
        except Exception:
            # best-effort: the fence reset below must still happen
            task_logger.exception(
                "validate_indexing_fence - Exception while marking index attempt as failed: "
                f"index_attempt={payload.index_attempt_id}",
            )

    redis_connector_index.reset()
    return
def validate_indexing_fences(
    tenant_id: str | None,
    celery_app: Celery,
    r: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    """Scan Redis for indexing fence keys and validate each of them.

    The beat lock is reacquired on every iteration because scanning plus
    per-fence validation may take longer than the lock timeout.
    """
    # task ids the indexing workers have prefetched but not yet started
    reserved_indexing_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )

    # validate all existing indexing jobs
    for fence_key_bytes in r.scan_iter(
        RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
    ):
        lock_beat.reacquire()
        with get_session_with_tenant(tenant_id) as db_session:
            validate_indexing_fence(
                tenant_id,
                fence_key_bytes,
                reserved_indexing_tasks,
                r_celery,
                db_session,
            )
def _should_index(
    cc_pair: ConnectorCredentialPair,
    last_index: IndexAttempt | None,
    search_settings_instance: SearchSettings,
    search_settings_primary: bool,
    secondary_index_building: bool,
    db_session: Session,
) -> bool:
    """Decide whether indexing should start for this cc pair / search setting.

    Only global settings and past indexing attempts are considered here;
    tactical checks (e.g. preventing overlap with a currently running task)
    are handled elsewhere.

    Returns True if we should try to index, False if not.
    """
    connector = cc_pair.connector

    # uncomment for debugging
    # task_logger.info(f"_should_index: "
    #                  f"cc_pair={cc_pair.id} "
    #                  f"connector={cc_pair.connector_id} "
    #                  f"refresh_freq={connector.refresh_freq}")

    # `NOT_APPLICABLE` sources are never indexed
    if connector.source == DocumentSource.NOT_APPLICABLE:
        return False

    # While a secondary index is being built, optionally freeze updates to
    # the currently live index. Users can still manually create single
    # indexing attempts via the UI for the currently in-use index.
    if (
        DISABLE_INDEX_UPDATE_ON_SWAP
        and search_settings_instance.status == IndexModelStatus.PRESENT
        and secondary_index_building
    ):
        return False

    # When switching over models, always index at least once
    if search_settings_instance.status == IndexModelStatus.FUTURE:
        if not last_index:
            # ingestion API / connector id 0 never get a fresh attempt
            return not (
                connector.id == 0 or connector.source == DocumentSource.INGESTION_API
            )

        # One success is enough (the model could never swap otherwise);
        # also skip while an attempt is queued or running.
        if last_index.status in (
            IndexingStatus.SUCCESS,
            IndexingStatus.NOT_STARTED,
            IndexingStatus.IN_PROGRESS,
        ):
            return False

    # If the connector is paused or is the ingestion API, don't index
    # NOTE: during an embedding model switch over, this is bypassed by the
    # future-model handling above
    if (
        not cc_pair.status.is_active()
        or connector.id == 0
        or connector.source == DocumentSource.INGESTION_API
    ):
        return False

    # a manual indexing trigger on the cc pair is honored for primary settings
    if search_settings_primary and cc_pair.indexing_trigger is not None:
        return True

    # if no attempt has ever occurred, index regardless of refresh_freq
    if not last_index:
        return True

    if connector.refresh_freq is None:
        return False

    # compare against the DB clock so all workers agree on "now"
    elapsed = get_db_current_time(db_session) - last_index.time_updated
    return elapsed.total_seconds() >= connector.refresh_freq
def try_creating_indexing_task(
    celery_app: Celery,
    cc_pair: ConnectorCredentialPair,
    search_settings: SearchSettings,
    reindex: bool,
    db_session: Session,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
    """Checks for any conditions that should block the indexing task from being
    created, then creates the task.

    Does not check for scheduling related conditions as this function
    is used to trigger indexing immediately.

    Args:
        celery_app: app used to dispatch the indexing proxy task.
        cc_pair: the connector/credential pair to index.
        search_settings: the search settings (index) to index into.
        reindex: if True, the attempt is created "from the beginning"
            rather than as an incremental update.
        db_session: session used to create/delete the index attempt row.
        r: Redis client used for the serialization lock and fence state.
        tenant_id: tenant the work belongs to (None in single-tenant mode).

    Returns:
        The id of the created index attempt, or None if creation was
        skipped (lock not acquired, already fenced, cc_pair deleting) or
        failed (in which case any partial state is rolled back).
    """

    LOCK_TIMEOUT = 30
    index_attempt_id: int | None = None

    # we need to serialize any attempt to trigger indexing since it can be triggered
    # either via celery beat or manually (API call)
    lock: RedisLock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
        timeout=LOCK_TIMEOUT,
    )

    # give up (rather than block the caller) if another trigger holds the lock
    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
    if not acquired:
        return None

    try:
        # NOTE(review): if this constructor raises, `redis_connector_index`
        # would be unbound in the except handler below — presumably it
        # cannot raise in practice; confirm.
        redis_connector = RedisConnector(tenant_id, cc_pair.id)
        redis_connector_index = redis_connector.new_index(search_settings.id)

        # skip if already indexing
        if redis_connector_index.fenced:
            return None

        # skip indexing if the cc_pair is deleting
        if redis_connector.delete.fenced:
            return None

        # re-read cc_pair status in case it changed while waiting for the lock
        db_session.refresh(cc_pair)
        if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
            return None

        # add a long running generator task to the queue
        redis_connector_index.generator_clear()

        # set a basic fence to start
        # (index_attempt_id / celery_task_id are filled in further below)
        payload = RedisConnectorIndexPayload(
            index_attempt_id=None,
            started=None,
            submitted=datetime.now(timezone.utc),
            celery_task_id=None,
        )

        redis_connector_index.set_active()
        redis_connector_index.set_fence(payload)

        # create the index attempt for tracking purposes
        # code elsewhere checks for index attempts without an associated redis key
        # and cleans them up
        # therefore we must create the attempt and the task after the fence goes up
        index_attempt_id = create_index_attempt(
            cc_pair.id,
            search_settings.id,
            from_beginning=reindex,
            db_session=db_session,
        )

        # deterministic task id so the watchdog can find the task later
        custom_task_id = redis_connector_index.generate_generator_task_id()

        # when the task is sent, we have yet to finish setting up the fence
        # therefore, the task must contain code that blocks until the fence is ready
        result = celery_app.send_task(
            OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
            kwargs=dict(
                index_attempt_id=index_attempt_id,
                cc_pair_id=cc_pair.id,
                search_settings_id=search_settings.id,
                tenant_id=tenant_id,
            ),
            queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
            task_id=custom_task_id,
            priority=OnyxCeleryPriority.MEDIUM,
        )
        if not result:
            raise RuntimeError("send_task for connector_indexing_proxy_task failed.")

        # now fill out the fence with the rest of the data
        redis_connector_index.set_active()

        payload.index_attempt_id = index_attempt_id
        payload.celery_task_id = result.id
        redis_connector_index.set_fence(payload)
    except Exception:
        task_logger.exception(
            f"try_creating_indexing_task - Unexpected exception: "
            f"cc_pair={cc_pair.id} "
            f"search_settings={search_settings.id}"
        )

        # roll back partial state: delete the attempt row and tear down the fence
        if index_attempt_id is not None:
            delete_index_attempt(db_session, index_attempt_id)
        redis_connector_index.set_fence(None)
        return None
    finally:
        if lock.owned():
            lock.release()

    return index_attempt_id
@@ -68,20 +68,22 @@ class Metric(BaseModel):
|
||||
task_logger.info(json.dumps(data))
|
||||
|
||||
def emit(self, tenant_id: str | None) -> None:
|
||||
# Convert value to appropriate type
|
||||
float_value = (
|
||||
float(self.value) if isinstance(self.value, (int, float)) else None
|
||||
)
|
||||
int_value = int(self.value) if isinstance(self.value, int) else None
|
||||
string_value = str(self.value) if isinstance(self.value, str) else None
|
||||
bool_value = bool(self.value) if isinstance(self.value, bool) else None
|
||||
|
||||
if (
|
||||
float_value is None
|
||||
and int_value is None
|
||||
and string_value is None
|
||||
and bool_value is None
|
||||
):
|
||||
# Convert value to appropriate type based on the input value
|
||||
bool_value = None
|
||||
float_value = None
|
||||
int_value = None
|
||||
string_value = None
|
||||
# NOTE: have to do bool first, since `isinstance(True, int)` is true
|
||||
# e.g. bool is a subclass of int
|
||||
if isinstance(self.value, bool):
|
||||
bool_value = self.value
|
||||
elif isinstance(self.value, int):
|
||||
int_value = self.value
|
||||
elif isinstance(self.value, float):
|
||||
float_value = self.value
|
||||
elif isinstance(self.value, str):
|
||||
string_value = self.value
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid metric value type: {type(self.value)} "
|
||||
f"({self.value}) for metric {self.name}."
|
||||
@@ -183,35 +185,41 @@ def _build_connector_start_latency_metric(
|
||||
)
|
||||
|
||||
|
||||
def _build_run_success_metric(
|
||||
cc_pair: ConnectorCredentialPair, recent_attempt: IndexAttempt, redis_std: Redis
|
||||
) -> Metric | None:
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=recent_attempt.id,
|
||||
)
|
||||
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {recent_attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
return None
|
||||
|
||||
if recent_attempt.status in [
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.FAILED,
|
||||
IndexingStatus.CANCELED,
|
||||
]:
|
||||
return Metric(
|
||||
key=metric_key,
|
||||
name="connector_run_succeeded",
|
||||
value=recent_attempt.status == IndexingStatus.SUCCESS,
|
||||
tags={"source": str(cc_pair.connector.source)},
|
||||
def _build_run_success_metrics(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
recent_attempts: list[IndexAttempt],
|
||||
redis_std: Redis,
|
||||
) -> list[Metric]:
|
||||
metrics = []
|
||||
for attempt in recent_attempts:
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=attempt.id,
|
||||
)
|
||||
|
||||
return None
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
if attempt.status in [
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.FAILED,
|
||||
IndexingStatus.CANCELED,
|
||||
]:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key,
|
||||
name="connector_run_succeeded",
|
||||
value=attempt.status == IndexingStatus.SUCCESS,
|
||||
tags={"source": str(cc_pair.connector.source)},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
@@ -224,7 +232,7 @@ def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Me
|
||||
|
||||
metrics = []
|
||||
for cc_pair in cc_pairs:
|
||||
# Get most recent attempt in the last hour
|
||||
# Get all attempts in the last hour
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
@@ -232,31 +240,29 @@ def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Me
|
||||
IndexAttempt.time_created >= one_hour_ago,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.limit(2)
|
||||
.all()
|
||||
)
|
||||
recent_attempt = recent_attempts[0] if recent_attempts else None
|
||||
most_recent_attempt = recent_attempts[0] if recent_attempts else None
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
|
||||
# if no metric to emit, skip
|
||||
if not recent_attempt:
|
||||
if most_recent_attempt is None:
|
||||
continue
|
||||
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, recent_attempt, second_most_recent_attempt, redis_std
|
||||
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
# Connector run success/failure
|
||||
run_success_metric = _build_run_success_metric(
|
||||
cc_pair, recent_attempt, redis_std
|
||||
run_success_metrics = _build_run_success_metrics(
|
||||
cc_pair, recent_attempts, redis_std
|
||||
)
|
||||
if run_success_metric:
|
||||
metrics.append(run_success_metric)
|
||||
metrics.extend(run_success_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
@@ -13,11 +13,11 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.background.celery.tasks.indexing.tasks import IndexingCallback
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -86,7 +86,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
|
||||
@@ -101,10 +101,17 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
for doc in doc_batch:
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if "\x00" in cleaned_doc.id:
|
||||
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
|
||||
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
|
||||
|
||||
if cleaned_doc.title and "\x00" in cleaned_doc.title:
|
||||
logger.warning(
|
||||
f"NUL characters found in document title: {cleaned_doc.title}"
|
||||
)
|
||||
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
|
||||
|
||||
if "\x00" in cleaned_doc.semantic_identifier:
|
||||
logger.warning(
|
||||
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
|
||||
@@ -120,6 +127,9 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
)
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
# since text can be longer, just replace to avoid double scan
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
|
||||
return cleaned_batch
|
||||
@@ -277,8 +287,6 @@ def _run_indexing(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
tracer_counter = 0
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.snap()
|
||||
@@ -347,16 +355,15 @@ def _run_indexing(
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
|
||||
# real work happens here!
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
index_pipeline_result = indexing_pipeline(
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch_cleaned)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
|
||||
net_doc_change += index_pipeline_result.new_docs
|
||||
chunk_count += index_pipeline_result.total_chunks
|
||||
document_count += index_pipeline_result.total_docs
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
@@ -365,9 +372,6 @@ def _run_indexing(
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
update_docs_indexed(
|
||||
@@ -378,6 +382,9 @@ def _run_indexing(
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
tracer_counter += 1
|
||||
if (
|
||||
INDEXING_TRACER_INTERVAL > 0
|
||||
|
||||
@@ -25,7 +25,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_prompts_by_ids
|
||||
from onyx.db.prompts import get_prompts_by_ids
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.persona import get_default_prompt__read_only
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
@@ -97,11 +98,12 @@ def compute_max_document_tokens(
|
||||
|
||||
|
||||
def compute_max_document_tokens_for_persona(
|
||||
db_session: Session,
|
||||
persona: Persona,
|
||||
actual_user_input: str | None = None,
|
||||
max_llm_token_override: int | None = None,
|
||||
) -> int:
|
||||
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
|
||||
prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
|
||||
return compute_max_document_tokens(
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
|
||||
|
||||
@@ -7,26 +7,6 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
|
||||
|
||||
|
||||
def build_dummy_prompt(
|
||||
system_prompt: str, task_prompt: str, retrieval_disabled: bool
|
||||
) -> str:
|
||||
if retrieval_disabled:
|
||||
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
return PARAMATERIZED_PROMPT.format(
|
||||
context_docs_str="<CONTEXT_DOCS>",
|
||||
user_query="<USER_QUERY>",
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
).strip()
|
||||
|
||||
|
||||
def translate_onyx_msg_to_langchain(
|
||||
|
||||
@@ -79,6 +79,8 @@ KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
|
||||
|
||||
# NOTE: we use this timeout / 4 in various places to refresh a lock
|
||||
# might be worth separating this timeout into separate timeouts for each situation
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
@@ -198,6 +200,7 @@ class SessionType(str, Enum):
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
DISLIKE = "dislike" # User dislikes the answer, used for metrics
|
||||
MIXED = "mixed" # User likes some answers and dislikes other, used for chat session metrics
|
||||
|
||||
|
||||
class SearchFeedbackType(str, Enum):
|
||||
@@ -291,6 +294,8 @@ class OnyxRedisLocks:
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
|
||||
|
||||
CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
|
||||
@@ -304,6 +309,13 @@ class OnyxCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
# a prefix used to distinguish system wide tasks in the cloud
|
||||
ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
|
||||
|
||||
# the tenant id we use for system level redis operations
|
||||
ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
@@ -331,6 +343,8 @@ class OnyxCeleryTask:
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
|
||||
@@ -258,7 +258,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
user_emails.append(email)
|
||||
return user_emails
|
||||
|
||||
def _get_all_drive_ids(self) -> set[str]:
|
||||
def get_all_drive_ids(self) -> set[str]:
|
||||
primary_drive_service = get_drive_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
@@ -353,7 +353,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
all_org_emails: list[str] = self._get_all_user_emails()
|
||||
|
||||
all_drive_ids: set[str] = self._get_all_drive_ids()
|
||||
all_drive_ids: set[str] = self.get_all_drive_ids()
|
||||
|
||||
drive_ids_to_retrieve: set[str] = set()
|
||||
folder_ids_to_retrieve: set[str] = set()
|
||||
@@ -437,7 +437,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# If all 3 are true, we already yielded from get_all_files_for_oauth
|
||||
return
|
||||
|
||||
all_drive_ids = self._get_all_drive_ids()
|
||||
all_drive_ids = self.get_all_drive_ids()
|
||||
drive_ids_to_retrieve: set[str] = set()
|
||||
folder_ids_to_retrieve: set[str] = set()
|
||||
if self._requested_shared_drive_ids or self._requested_folder_ids:
|
||||
|
||||
@@ -252,6 +252,7 @@ def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
|
||||
id=file["webViewLink"],
|
||||
perm_sync_data={
|
||||
"doc_id": file.get("id"),
|
||||
"drive_id": file.get("driveId"),
|
||||
"permissions": file.get("permissions", []),
|
||||
"permission_ids": file.get("permissionIds", []),
|
||||
"name": file.get("name"),
|
||||
|
||||
@@ -19,7 +19,7 @@ FILE_FIELDS = (
|
||||
"shortcutDetails, owners(emailAddress), size)"
|
||||
)
|
||||
SLIM_FILE_FIELDS = (
|
||||
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
|
||||
"nextPageToken, files(mimeType, driveId, id, name, permissions(emailAddress, type), "
|
||||
"permissionIds, webViewLink, owners(emailAddress))"
|
||||
)
|
||||
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"
|
||||
|
||||
@@ -12,8 +12,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.constants import SYSTEM_USER
|
||||
from onyx.db.constants import SystemUser
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -35,13 +33,8 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
if isinstance(user, SystemUser):
|
||||
if user is SYSTEM_USER:
|
||||
return stmt
|
||||
raise ValueError("Bad SystemUser object")
|
||||
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
return stmt
|
||||
@@ -101,7 +94,7 @@ def _add_user_filters(
|
||||
|
||||
def get_connector_credential_pairs_for_user(
|
||||
db_session: Session,
|
||||
user: User | None | SystemUser,
|
||||
user: User | None,
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
@@ -112,7 +105,6 @@ def get_connector_credential_pairs_for_user(
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
@@ -123,11 +115,12 @@ def get_connector_credential_pairs(
|
||||
db_session: Session,
|
||||
ids: list[int] | None = None,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
db_session=db_session,
|
||||
user=SYSTEM_USER,
|
||||
ids=ids,
|
||||
)
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def add_deletion_failure_message(
|
||||
@@ -162,7 +155,7 @@ def get_connector_credential_pair_for_user(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user: User | None | SystemUser,
|
||||
user: User | None,
|
||||
get_editable: bool = True,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
@@ -178,18 +171,17 @@ def get_connector_credential_pair(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
return get_connector_credential_pair_for_user(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
user=SYSTEM_USER,
|
||||
)
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
user: User | None | SystemUser,
|
||||
user: User | None,
|
||||
get_editable: bool = True,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
@@ -203,11 +195,10 @@ def get_connector_credential_pair_from_id(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
return get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=SYSTEM_USER,
|
||||
)
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def get_last_successful_attempt_time(
|
||||
|
||||
@@ -1,11 +1 @@
|
||||
from typing import Final
|
||||
|
||||
|
||||
SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"
|
||||
|
||||
|
||||
class SystemUser:
|
||||
"""Represents the system user for internal operations"""
|
||||
|
||||
|
||||
SYSTEM_USER: Final = SystemUser()
|
||||
|
||||
@@ -14,8 +14,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from onyx.db.constants import SYSTEM_USER
|
||||
from onyx.db.constants import SystemUser
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -44,17 +42,11 @@ PUBLIC_CREDENTIAL_ID = 0
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select,
|
||||
user: User | None | SystemUser,
|
||||
user: User | None,
|
||||
get_editable: bool = True,
|
||||
) -> Select:
|
||||
"""Attaches filters to the statement to ensure that the user can only
|
||||
access the appropriate credentials"""
|
||||
|
||||
if isinstance(user, SystemUser):
|
||||
if user is SYSTEM_USER:
|
||||
return stmt
|
||||
raise ValueError("Bad SystemUser object")
|
||||
|
||||
if user is None:
|
||||
if not DISABLE_AUTH:
|
||||
raise ValueError("Anonymous users are not allowed to access credentials")
|
||||
@@ -159,7 +151,7 @@ def fetch_credentials_for_user(
|
||||
|
||||
def fetch_credential_by_id_for_user(
|
||||
credential_id: int,
|
||||
user: User | None | SystemUser,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
get_editable: bool = True,
|
||||
) -> Credential | None:
|
||||
@@ -179,16 +171,16 @@ def fetch_credential_by_id(
|
||||
db_session: Session,
|
||||
credential_id: int,
|
||||
) -> Credential | None:
|
||||
return fetch_credential_by_id_for_user(
|
||||
credential_id=credential_id,
|
||||
user=SYSTEM_USER,
|
||||
db_session=db_session,
|
||||
)
|
||||
stmt = select(Credential).distinct()
|
||||
stmt = stmt.where(Credential.id == credential_id)
|
||||
result = db_session.execute(stmt)
|
||||
credential = result.scalar_one_or_none()
|
||||
return credential
|
||||
|
||||
|
||||
def fetch_credentials_by_source_for_user(
|
||||
db_session: Session,
|
||||
user: User | None | SystemUser,
|
||||
user: User | None,
|
||||
document_source: DocumentSource | None = None,
|
||||
get_editable: bool = True,
|
||||
) -> list[Credential]:
|
||||
@@ -202,11 +194,9 @@ def fetch_credentials_by_source(
|
||||
db_session: Session,
|
||||
document_source: DocumentSource | None = None,
|
||||
) -> list[Credential]:
|
||||
return fetch_credentials_by_source_for_user(
|
||||
db_session=db_session,
|
||||
user=SYSTEM_USER,
|
||||
document_source=document_source,
|
||||
)
|
||||
base_query = select(Credential).where(Credential.source == document_source)
|
||||
credentials = db_session.execute(base_query).scalars().all()
|
||||
return list(credentials)
|
||||
|
||||
|
||||
def swap_credentials_connector(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@@ -13,6 +14,7 @@ from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import tuple_
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.engine.util import TransactionalContext
|
||||
from sqlalchemy.exc import OperationalError
|
||||
@@ -226,10 +228,13 @@ def get_document_counts_for_cc_pairs(
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
tuple_(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
).in_(cc_ids)
|
||||
and_(
|
||||
tuple_(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
).in_(cc_ids),
|
||||
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
|
||||
)
|
||||
)
|
||||
.group_by(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
@@ -382,18 +387,40 @@ def upsert_document_by_connector_credential_pair(
|
||||
id=doc_id,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
has_been_indexed=False,
|
||||
)
|
||||
)
|
||||
for doc_id in document_ids
|
||||
]
|
||||
)
|
||||
# for now, there are no columns to update. If more metadata is added, then this
|
||||
# needs to change to an `on_conflict_do_update`
|
||||
# this must be `on_conflict_do_nothing` rather than `on_conflict_do_update`
|
||||
# since we don't want to update the `has_been_indexed` field for documents
|
||||
# that already exist
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
|
||||
db_session.execute(on_conflict_stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_document_as_indexed_for_cc_pair__no_commit(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
document_ids: Iterable[str],
|
||||
) -> None:
|
||||
"""Should be called only after a successful index operation for a batch."""
|
||||
db_session.execute(
|
||||
update(DocumentByConnectorCredentialPair)
|
||||
.where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
DocumentByConnectorCredentialPair.id.in_(document_ids),
|
||||
)
|
||||
)
|
||||
.values(has_been_indexed=True)
|
||||
)
|
||||
|
||||
|
||||
def update_docs_updated_at__no_commit(
|
||||
ids_to_new_updated_at: dict[str, datetime],
|
||||
db_session: Session,
|
||||
|
||||
@@ -15,8 +15,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.constants import SYSTEM_USER
|
||||
from onyx.db.constants import SystemUser
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -37,13 +35,8 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
if isinstance(user, SystemUser):
|
||||
if user is SYSTEM_USER:
|
||||
return stmt
|
||||
raise ValueError("Bad SystemUser object")
|
||||
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
return stmt
|
||||
@@ -494,7 +487,7 @@ def fetch_document_sets(
|
||||
|
||||
def fetch_all_document_sets_for_user(
|
||||
db_session: Session,
|
||||
user: User | None | SystemUser,
|
||||
user: User | None,
|
||||
get_editable: bool = True,
|
||||
) -> Sequence[DocumentSetDBModel]:
|
||||
stmt = select(DocumentSetDBModel).distinct()
|
||||
@@ -502,15 +495,6 @@ def fetch_all_document_sets_for_user(
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def fetch_all_document_sets(
|
||||
db_session: Session,
|
||||
) -> Sequence[DocumentSetDBModel]:
|
||||
return fetch_all_document_sets_for_user(
|
||||
db_session=db_session,
|
||||
user=SYSTEM_USER,
|
||||
)
|
||||
|
||||
|
||||
def fetch_documents_for_document_set_paginated(
|
||||
document_set_id: int,
|
||||
db_session: Session,
|
||||
|
||||
@@ -240,8 +240,11 @@ class SqlEngine:
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
|
||||
result = session.execute(
|
||||
text(
|
||||
|
||||
@@ -941,6 +941,12 @@ class DocumentByConnectorCredentialPair(Base):
|
||||
ForeignKey("credential.id"), primary_key=True
|
||||
)
|
||||
|
||||
# used to better keep track of document counts at a connector level
|
||||
# e.g. if a document is added as part of permission syncing, it should
|
||||
# not be counted as part of the connector's document count until
|
||||
# the actual indexing is complete
|
||||
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
connector: Mapped[Connector] = relationship(
|
||||
"Connector", back_populates="documents_by_connector"
|
||||
)
|
||||
@@ -955,6 +961,14 @@ class DocumentByConnectorCredentialPair(Base):
|
||||
"credential_id",
|
||||
unique=False,
|
||||
),
|
||||
# Index to optimize get_document_counts_for_cc_pairs query pattern
|
||||
Index(
|
||||
"idx_document_cc_pair_counts",
|
||||
"connector_id",
|
||||
"credential_id",
|
||||
"has_been_indexed",
|
||||
unique=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -8,7 +7,6 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import not_
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
@@ -23,9 +21,6 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.constants import SYSTEM_USER
|
||||
from onyx.db.constants import SystemUser
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
@@ -37,8 +32,8 @@ from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -46,13 +41,8 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
if isinstance(user, SystemUser):
|
||||
if user is SYSTEM_USER:
|
||||
return stmt
|
||||
raise ValueError("Bad SystemUser object")
|
||||
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
return stmt
|
||||
@@ -114,14 +104,8 @@ def _add_user_filters(
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
# fetch_persona_by_id is used to fetch a persona by its ID. It is used to fetch a persona by its ID.
|
||||
|
||||
|
||||
def fetch_persona_by_id_for_user(
|
||||
db_session: Session,
|
||||
persona_id: int,
|
||||
user: User | None | SystemUser,
|
||||
get_editable: bool = True,
|
||||
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
|
||||
) -> Persona:
|
||||
stmt = select(Persona).where(Persona.id == persona_id).distinct()
|
||||
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
|
||||
@@ -134,17 +118,6 @@ def fetch_persona_by_id_for_user(
|
||||
return persona
|
||||
|
||||
|
||||
def fetch_persona_by_id(
|
||||
db_session: Session,
|
||||
persona_id: int,
|
||||
) -> Persona:
|
||||
return fetch_persona_by_id_for_user(
|
||||
db_session=db_session,
|
||||
persona_id=persona_id,
|
||||
user=SYSTEM_USER,
|
||||
)
|
||||
|
||||
|
||||
def get_best_persona_id_for_user(
|
||||
db_session: Session, user: User | None, persona_id: int | None = None
|
||||
) -> int | None:
|
||||
@@ -205,7 +178,7 @@ def make_persona_private(
|
||||
|
||||
def create_update_persona(
|
||||
persona_id: int | None,
|
||||
create_persona_request: CreatePersonaRequest,
|
||||
create_persona_request: PersonaUpsertRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> PersonaSnapshot:
|
||||
@@ -213,14 +186,36 @@ def create_update_persona(
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
try:
|
||||
persona_data = {
|
||||
"persona_id": persona_id,
|
||||
"user": user,
|
||||
"db_session": db_session,
|
||||
**create_persona_request.model_dump(exclude={"users", "groups"}),
|
||||
}
|
||||
all_prompt_ids = create_persona_request.prompt_ids
|
||||
|
||||
persona = upsert_persona(**persona_data)
|
||||
if not all_prompt_ids:
|
||||
raise ValueError("No prompt IDs provided")
|
||||
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
description=create_persona_request.description,
|
||||
name=create_persona_request.name,
|
||||
prompt_ids=all_prompt_ids,
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
is_public=create_persona_request.is_public,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
icon_color=create_persona_request.icon_color,
|
||||
icon_shape=create_persona_request.icon_shape,
|
||||
uploaded_image_id=create_persona_request.uploaded_image_id,
|
||||
display_priority=create_persona_request.display_priority,
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
@@ -286,27 +281,9 @@ def update_persona_public_status(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_prompts(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
include_default: bool = True,
|
||||
include_deleted: bool = False,
|
||||
) -> Sequence[Prompt]:
|
||||
stmt = select(Prompt).where(
|
||||
or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
|
||||
)
|
||||
|
||||
if not include_default:
|
||||
stmt = stmt.where(Prompt.default_prompt.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Prompt.deleted.is_(False))
|
||||
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_personas_for_user(
|
||||
# if user is `None` assume the user is an admin or auth is disabled
|
||||
user: User | None | SystemUser,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
get_editable: bool = True,
|
||||
include_default: bool = True,
|
||||
@@ -336,10 +313,10 @@ def get_personas_for_user(
|
||||
|
||||
|
||||
def get_personas(db_session: Session) -> Sequence[Persona]:
|
||||
return get_personas_for_user(
|
||||
user=SYSTEM_USER,
|
||||
db_session=db_session,
|
||||
)
|
||||
stmt = select(Persona).distinct()
|
||||
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
|
||||
stmt = stmt.where(Persona.deleted.is_(False))
|
||||
return db_session.execute(stmt).unique().scalars().all()
|
||||
|
||||
|
||||
def mark_persona_as_deleted(
|
||||
@@ -395,65 +372,6 @@ def update_all_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def upsert_prompt(
|
||||
user: User | None,
|
||||
name: str,
|
||||
description: str,
|
||||
system_prompt: str,
|
||||
task_prompt: str,
|
||||
include_citations: bool,
|
||||
datetime_aware: bool,
|
||||
personas: list[Persona] | None,
|
||||
db_session: Session,
|
||||
prompt_id: int | None = None,
|
||||
default_prompt: bool = True,
|
||||
commit: bool = True,
|
||||
) -> Prompt:
|
||||
if prompt_id is not None:
|
||||
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
|
||||
else:
|
||||
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
|
||||
|
||||
if prompt:
|
||||
if not default_prompt and prompt.default_prompt:
|
||||
raise ValueError("Cannot update default prompt with non-default.")
|
||||
|
||||
prompt.name = name
|
||||
prompt.description = description
|
||||
prompt.system_prompt = system_prompt
|
||||
prompt.task_prompt = task_prompt
|
||||
prompt.include_citations = include_citations
|
||||
prompt.datetime_aware = datetime_aware
|
||||
prompt.default_prompt = default_prompt
|
||||
|
||||
if personas is not None:
|
||||
prompt.personas.clear()
|
||||
prompt.personas = personas
|
||||
|
||||
else:
|
||||
prompt = Prompt(
|
||||
id=prompt_id,
|
||||
user_id=user.id if user else None,
|
||||
name=name,
|
||||
description=description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
datetime_aware=datetime_aware,
|
||||
default_prompt=default_prompt,
|
||||
personas=personas or [],
|
||||
)
|
||||
db_session.add(prompt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
else:
|
||||
# Flush the session so that the Prompt has an ID
|
||||
db_session.flush()
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -498,6 +416,15 @@ def upsert_persona(
|
||||
persona_name=name, user=user, db_session=db_session
|
||||
)
|
||||
|
||||
if existing_persona:
|
||||
# this checks if the user has permission to edit the persona
|
||||
# will raise an Exception if the user does not have permission
|
||||
existing_persona = fetch_persona_by_id_for_user(
|
||||
db_session=db_session,
|
||||
persona_id=existing_persona.id,
|
||||
user=user,
|
||||
get_editable=True,
|
||||
)
|
||||
# Fetch and attach tools by IDs
|
||||
tools = None
|
||||
if tool_ids is not None:
|
||||
@@ -543,15 +470,6 @@ def upsert_persona(
|
||||
if existing_persona.builtin_persona and not builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
# this checks if the user has permission to edit the persona
|
||||
# will raise an Exception if the user does not have permission
|
||||
existing_persona = fetch_persona_by_id_for_user(
|
||||
db_session=db_session,
|
||||
persona_id=existing_persona.id,
|
||||
user=user,
|
||||
get_editable=True,
|
||||
)
|
||||
|
||||
# The following update excludes `default`, `built-in`, and display priority.
|
||||
# Display priority is handled separately in the `display-priority` endpoint.
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
@@ -640,16 +558,6 @@ def upsert_persona(
|
||||
return persona
|
||||
|
||||
|
||||
def mark_prompt_as_deleted(
|
||||
prompt_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session)
|
||||
prompt.deleted = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_old_default_personas(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
@@ -687,69 +595,6 @@ def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(
|
||||
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||
).all()
|
||||
|
||||
return list(prompts)
|
||||
|
||||
|
||||
def get_prompt_by_id(
|
||||
prompt_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
) -> Prompt:
|
||||
stmt = select(Prompt).where(Prompt.id == prompt_id)
|
||||
|
||||
# if user is not specified OR they are an admin, they should
|
||||
# have access to all prompts, so this where clause is not needed
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None)))
|
||||
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Prompt.deleted.is_(False))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if prompt is None:
|
||||
raise ValueError(
|
||||
f"Prompt with ID {prompt_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def _get_default_prompt(db_session: Session) -> Prompt:
|
||||
stmt = select(Prompt).where(Prompt.id == 0)
|
||||
result = db_session.execute(stmt)
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if prompt is None:
|
||||
raise RuntimeError("Default Prompt not found")
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_default_prompt(db_session: Session) -> Prompt:
|
||||
return _get_default_prompt(db_session)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_default_prompt__read_only() -> Prompt:
|
||||
"""Due to the way lru_cache / SQLAlchemy works, this can cause issues
|
||||
when trying to attach the returned `Prompt` object to a `Persona`. If you are
|
||||
doing anything other than reading, you should use the `get_default_prompt`
|
||||
method instead."""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
return _get_default_prompt(db_session)
|
||||
|
||||
|
||||
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
|
||||
# a direct mapping indicating whether a user has access to a specific persona?
|
||||
def get_persona_by_id(
|
||||
@@ -821,22 +666,6 @@ def get_personas_by_ids(
|
||||
return personas
|
||||
|
||||
|
||||
def get_prompt_by_name(
|
||||
prompt_name: str, user: User | None, db_session: Session
|
||||
) -> Prompt | None:
|
||||
stmt = select(Prompt).where(Prompt.name == prompt_name)
|
||||
|
||||
# if user is not specified OR they are an admin, they should
|
||||
# have access to all prompts, so this where clause is not needed
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(Prompt.user_id == user.id)
|
||||
|
||||
# Order by ID to ensure consistent result when multiple prompts exist
|
||||
stmt = stmt.order_by(Prompt.id).limit(1)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
|
||||
def delete_persona_by_name(
|
||||
persona_name: str, db_session: Session, is_default: bool = True
|
||||
) -> None:
|
||||
|
||||
119
backend/onyx/db/prompts.py
Normal file
119
backend/onyx/db/prompts.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
# Note: As prompts are fairly innocuous/harmless, there are no protections
|
||||
# to prevent users from messing with prompts of other users.
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_default_prompt(db_session: Session) -> Prompt:
|
||||
stmt = select(Prompt).where(Prompt.id == 0)
|
||||
result = db_session.execute(stmt)
|
||||
prompt = result.scalar_one_or_none()
|
||||
|
||||
if prompt is None:
|
||||
raise RuntimeError("Default Prompt not found")
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def get_default_prompt(db_session: Session) -> Prompt:
|
||||
return _get_default_prompt(db_session)
|
||||
|
||||
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(
|
||||
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||
).all()
|
||||
|
||||
return list(prompts)
|
||||
|
||||
|
||||
def get_prompt_by_name(
|
||||
prompt_name: str, user: User | None, db_session: Session
|
||||
) -> Prompt | None:
|
||||
stmt = select(Prompt).where(Prompt.name == prompt_name)
|
||||
|
||||
# if user is not specified OR they are an admin, they should
|
||||
# have access to all prompts, so this where clause is not needed
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(Prompt.user_id == user.id)
|
||||
|
||||
# Order by ID to ensure consistent result when multiple prompts exist
|
||||
stmt = stmt.order_by(Prompt.id).limit(1)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
|
||||
def build_prompt_name_from_persona_name(persona_name: str) -> str:
|
||||
return f"default-prompt__{persona_name}"
|
||||
|
||||
|
||||
def upsert_prompt(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
name: str,
|
||||
system_prompt: str,
|
||||
task_prompt: str,
|
||||
datetime_aware: bool,
|
||||
prompt_id: int | None = None,
|
||||
personas: list[Persona] | None = None,
|
||||
include_citations: bool = False,
|
||||
default_prompt: bool = True,
|
||||
# Support backwards compatibility
|
||||
description: str | None = None,
|
||||
) -> Prompt:
|
||||
if description is None:
|
||||
description = f"Default prompt for {name}"
|
||||
|
||||
if prompt_id is not None:
|
||||
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
|
||||
else:
|
||||
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
|
||||
|
||||
if prompt:
|
||||
if not default_prompt and prompt.default_prompt:
|
||||
raise ValueError("Cannot update default prompt with non-default.")
|
||||
|
||||
prompt.name = name
|
||||
prompt.description = description
|
||||
prompt.system_prompt = system_prompt
|
||||
prompt.task_prompt = task_prompt
|
||||
prompt.include_citations = include_citations
|
||||
prompt.datetime_aware = datetime_aware
|
||||
prompt.default_prompt = default_prompt
|
||||
|
||||
if personas is not None:
|
||||
prompt.personas.clear()
|
||||
prompt.personas = personas
|
||||
|
||||
else:
|
||||
prompt = Prompt(
|
||||
id=prompt_id,
|
||||
user_id=user.id if user else None,
|
||||
name=name,
|
||||
description=description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
datetime_aware=datetime_aware,
|
||||
default_prompt=default_prompt,
|
||||
personas=personas or [],
|
||||
)
|
||||
db_session.add(prompt)
|
||||
|
||||
# Flush the session so that the Prompt has an ID
|
||||
db_session.flush()
|
||||
|
||||
return prompt
|
||||
@@ -12,9 +12,9 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__DocumentSet
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_default_prompt
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.prompts import get_default_prompt
|
||||
from onyx.utils.errors import EERequiredError
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.document import fetch_chunk_counts_for_documents
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
|
||||
from onyx.db.document import prepare_to_modify_documents
|
||||
from onyx.db.document import update_docs_chunk_count__no_commit
|
||||
from onyx.db.document import update_docs_last_modified__no_commit
|
||||
@@ -55,12 +56,23 @@ class DocumentBatchPrepareContext(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class IndexingPipelineResult(BaseModel):
|
||||
# number of documents that are completely new (e.g. did
|
||||
# not exist as a part of this OR any other connector)
|
||||
new_docs: int
|
||||
# NOTE: need total_docs, since the pipeline can skip some docs
|
||||
# (e.g. not even insert them into Postgres)
|
||||
total_docs: int
|
||||
# number of chunks that were inserted into Vespa
|
||||
total_chunks: int
|
||||
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
) -> tuple[int, int]:
|
||||
) -> IndexingPipelineResult:
|
||||
...
|
||||
|
||||
|
||||
@@ -147,10 +159,12 @@ def index_doc_batch_with_handler(
|
||||
db_session: Session,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> tuple[int, int]:
|
||||
r = (0, 0)
|
||||
) -> IndexingPipelineResult:
|
||||
index_pipeline_result = IndexingPipelineResult(
|
||||
new_docs=0, total_docs=len(document_batch), total_chunks=0
|
||||
)
|
||||
try:
|
||||
r = index_doc_batch(
|
||||
index_pipeline_result = index_doc_batch(
|
||||
chunker=chunker,
|
||||
embedder=embedder,
|
||||
document_index=document_index,
|
||||
@@ -203,7 +217,7 @@ def index_doc_batch_with_handler(
|
||||
else:
|
||||
pass
|
||||
|
||||
return r
|
||||
return index_pipeline_result
|
||||
|
||||
|
||||
def index_doc_batch_prepare(
|
||||
@@ -227,6 +241,15 @@ def index_doc_batch_prepare(
|
||||
if not ignore_time_skip
|
||||
else documents
|
||||
)
|
||||
if len(updatable_docs) != len(documents):
|
||||
updatable_doc_ids = [doc.id for doc in updatable_docs]
|
||||
skipped_doc_ids = [
|
||||
doc.id for doc in documents if doc.id not in updatable_doc_ids
|
||||
]
|
||||
logger.info(
|
||||
f"Skipping {len(skipped_doc_ids)} documents "
|
||||
f"because they are up to date. Skipped doc IDs: {skipped_doc_ids}"
|
||||
)
|
||||
|
||||
# for all updatable docs, upsert into the DB
|
||||
# Does not include doc_updated_at which is also used to indicate a successful update
|
||||
@@ -263,21 +286,6 @@ def index_doc_batch_prepare(
|
||||
def filter_documents(document_batch: list[Document]) -> list[Document]:
|
||||
documents: list[Document] = []
|
||||
for document in document_batch:
|
||||
# Remove any NUL characters from title/semantic_id
|
||||
# This is a known issue with the Zendesk connector
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if document.title:
|
||||
document.title = document.title.replace("\x00", "")
|
||||
if document.semantic_identifier:
|
||||
document.semantic_identifier = document.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
# Remove NUL characters from all sections
|
||||
for section in document.sections:
|
||||
if section.text is not None:
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
empty_contents = not any(section.text.strip() for section in document.sections)
|
||||
if (
|
||||
(not document.title or not document.title.strip())
|
||||
@@ -333,7 +341,7 @@ def index_doc_batch(
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
|
||||
) -> tuple[int, int]:
|
||||
) -> IndexingPipelineResult:
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
Note that the documents should already be batched at this point so that it does not inflate the
|
||||
memory requirements
|
||||
@@ -359,7 +367,18 @@ def index_doc_batch(
|
||||
db_session=db_session,
|
||||
)
|
||||
if not ctx:
|
||||
return 0, 0
|
||||
# even though we didn't actually index anything, we should still
|
||||
# mark them as "completed" for the CC Pair in order to make the
|
||||
# counts match
|
||||
mark_document_as_indexed_for_cc_pair__no_commit(
|
||||
connector_id=index_attempt_metadata.connector_id,
|
||||
credential_id=index_attempt_metadata.credential_id,
|
||||
document_ids=[doc.id for doc in filtered_documents],
|
||||
db_session=db_session,
|
||||
)
|
||||
return IndexingPipelineResult(
|
||||
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
|
||||
)
|
||||
|
||||
logger.debug("Starting chunking")
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
|
||||
@@ -425,7 +444,8 @@ def index_doc_batch(
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
|
||||
"Indexing the following chunks: "
|
||||
f"{[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
|
||||
)
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
@@ -440,14 +460,17 @@ def index_doc_batch(
|
||||
),
|
||||
)
|
||||
|
||||
successful_doc_ids = [record.document_id for record in insertion_records]
|
||||
successful_docs = [
|
||||
doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids
|
||||
]
|
||||
successful_doc_ids = {record.document_id for record in insertion_records}
|
||||
if successful_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Successful IDs: {successful_doc_ids}"
|
||||
)
|
||||
|
||||
last_modified_ids = []
|
||||
ids_to_new_updated_at = {}
|
||||
for doc in successful_docs:
|
||||
for doc in ctx.updatable_docs:
|
||||
last_modified_ids.append(doc.id)
|
||||
# doc_updated_at is the source's idea (on the other end of the connector)
|
||||
# of when the doc was last modified
|
||||
@@ -469,11 +492,24 @@ def index_doc_batch(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# these documents can now be counted as part of the CC Pairs
|
||||
# document count, so we need to mark them as indexed
|
||||
# NOTE: even documents we skipped since they were already up
|
||||
# to date should be counted here in order to maintain parity
|
||||
# between CC Pair and index attempt counts
|
||||
mark_document_as_indexed_for_cc_pair__no_commit(
|
||||
connector_id=index_attempt_metadata.connector_id,
|
||||
credential_id=index_attempt_metadata.credential_id,
|
||||
document_ids=[doc.id for doc in filtered_documents],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
result = (
|
||||
len([r for r in insertion_records if r.already_existed is False]),
|
||||
len(access_aware_chunks),
|
||||
result = IndexingPipelineResult(
|
||||
new_docs=len([r for r in insertion_records if r.already_existed is False]),
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=len(access_aware_chunks),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -64,7 +64,6 @@ from onyx.server.features.input_prompt.api import (
|
||||
from onyx.server.features.notifications.api import router as notification_router
|
||||
from onyx.server.features.persona.api import admin_router as admin_persona_router
|
||||
from onyx.server.features.persona.api import basic_router as persona_router
|
||||
from onyx.server.features.prompt.api import basic_router as prompt_router
|
||||
from onyx.server.features.tool.api import admin_router as admin_tool_router
|
||||
from onyx.server.features.tool.api import router as tool_router
|
||||
from onyx.server.gpts.api import router as gpts_router
|
||||
@@ -296,7 +295,6 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, persona_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_persona_router)
|
||||
include_router_with_global_prefix_prepended(application, notification_router)
|
||||
include_router_with_global_prefix_prepended(application, prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, tool_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_tool_router)
|
||||
include_router_with_global_prefix_prepended(application, state_router)
|
||||
|
||||
@@ -118,32 +118,6 @@ You should always get right to the point, and never use extraneous language.
|
||||
"""
|
||||
|
||||
|
||||
# This is only for visualization for the users to specify their own prompts
|
||||
# The actual flow does not work like this
|
||||
PARAMATERIZED_PROMPT = f"""
|
||||
{{system_prompt}}
|
||||
|
||||
CONTEXT:
|
||||
{GENERAL_SEP_PAT}
|
||||
{{context_docs_str}}
|
||||
{GENERAL_SEP_PAT}
|
||||
|
||||
{{task_prompt}}
|
||||
|
||||
{QUESTION_PAT.upper()} {{user_query}}
|
||||
RESPONSE:
|
||||
""".strip()
|
||||
|
||||
PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
|
||||
{{system_prompt}}
|
||||
|
||||
{{task_prompt}}
|
||||
|
||||
{QUESTION_PAT.upper()} {{user_query}}
|
||||
RESPONSE:
|
||||
""".strip()
|
||||
|
||||
|
||||
# CURRENTLY DISABLED, CANNOT USE THIS ONE
|
||||
# Default chain-of-thought style json prompt which uses multiple docs
|
||||
# This one has a section for the LLM to output some non-answer "thoughts"
|
||||
|
||||
@@ -9,7 +9,7 @@ from pydantic import BaseModel
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -147,7 +147,7 @@ class RedisConnectorPermissionSync:
|
||||
for doc_perm in new_permissions:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
@@ -7,7 +7,7 @@ from celery import Celery
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -125,7 +125,7 @@ class RedisConnectorPrune:
|
||||
for doc_id in documents_to_prune:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
@@ -12,9 +12,9 @@ from onyx.db.models import DocumentSet as DocumentSetDBModel
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Prompt as PromptDBModel
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.db.persona import get_prompt_by_name
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.persona import upsert_prompt
|
||||
from onyx.db.prompts import get_prompt_by_name
|
||||
from onyx.db.prompts import upsert_prompt
|
||||
|
||||
|
||||
def load_prompts_from_yaml(
|
||||
@@ -26,6 +26,7 @@ def load_prompts_from_yaml(
|
||||
all_prompts = data.get("prompts", [])
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
db_session=db_session,
|
||||
user=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
@@ -36,9 +37,8 @@ def load_prompts_from_yaml(
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def load_input_prompts_from_yaml(
|
||||
|
||||
@@ -7,6 +7,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import DocumentErrorSummary
|
||||
@@ -212,6 +213,7 @@ PaginatedType = TypeVar(
|
||||
IndexAttemptSnapshot,
|
||||
FullUserSnapshot,
|
||||
InvitedUserSnapshot,
|
||||
ChatSessionMinimal,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.prompt_builder.utils import build_dummy_prompt
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NotificationType
|
||||
@@ -36,19 +35,21 @@ from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared_users
|
||||
from onyx.db.persona import update_persona_visibility
|
||||
from onyx.db.prompts import build_prompt_name_from_persona_name
|
||||
from onyx.db.prompts import upsert_prompt
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.secondary_llm_flows.starter_message_creation import (
|
||||
generate_starter_messages,
|
||||
)
|
||||
from onyx.server.features.persona.models import CreatePersonaRequest
|
||||
from onyx.server.features.persona.models import GenerateStarterMessageRequest
|
||||
from onyx.server.features.persona.models import ImageGenerationToolStatus
|
||||
from onyx.server.features.persona.models import PersonaLabelCreate
|
||||
from onyx.server.features.persona.models import PersonaLabelResponse
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PromptTemplateResponse
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.features.persona.models import PromptSnapshot
|
||||
from onyx.server.models import DisplayPriorityRequest
|
||||
from onyx.tools.utils import is_image_generation_available
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -173,18 +174,37 @@ def upload_file(
|
||||
|
||||
@basic_router.post("")
|
||||
def create_persona(
|
||||
create_persona_request: CreatePersonaRequest,
|
||||
persona_upsert_request: PersonaUpsertRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> PersonaSnapshot:
|
||||
prompt_id = (
|
||||
persona_upsert_request.prompt_ids[0]
|
||||
if persona_upsert_request.prompt_ids
|
||||
and len(persona_upsert_request.prompt_ids) > 0
|
||||
else None
|
||||
)
|
||||
prompt = upsert_prompt(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
|
||||
system_prompt=persona_upsert_request.system_prompt,
|
||||
task_prompt=persona_upsert_request.task_prompt,
|
||||
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
|
||||
datetime_aware=False,
|
||||
include_citations=persona_upsert_request.include_citations,
|
||||
prompt_id=prompt_id,
|
||||
)
|
||||
prompt_snapshot = PromptSnapshot.from_model(prompt)
|
||||
persona_upsert_request.prompt_ids = [prompt.id]
|
||||
persona_snapshot = create_update_persona(
|
||||
persona_id=None,
|
||||
create_persona_request=create_persona_request,
|
||||
create_persona_request=persona_upsert_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
persona_snapshot.prompts = [prompt_snapshot]
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
@@ -202,16 +222,37 @@ def create_persona(
|
||||
@basic_router.patch("/{persona_id}")
|
||||
def update_persona(
|
||||
persona_id: int,
|
||||
update_persona_request: CreatePersonaRequest,
|
||||
persona_upsert_request: PersonaUpsertRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PersonaSnapshot:
|
||||
return create_update_persona(
|
||||
prompt_id = (
|
||||
persona_upsert_request.prompt_ids[0]
|
||||
if persona_upsert_request.prompt_ids
|
||||
and len(persona_upsert_request.prompt_ids) > 0
|
||||
else None
|
||||
)
|
||||
prompt = upsert_prompt(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
|
||||
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
|
||||
datetime_aware=False,
|
||||
system_prompt=persona_upsert_request.system_prompt,
|
||||
task_prompt=persona_upsert_request.task_prompt,
|
||||
include_citations=persona_upsert_request.include_citations,
|
||||
prompt_id=prompt_id,
|
||||
)
|
||||
prompt_snapshot = PromptSnapshot.from_model(prompt)
|
||||
persona_upsert_request.prompt_ids = [prompt.id]
|
||||
persona_snapshot = create_update_persona(
|
||||
persona_id=persona_id,
|
||||
create_persona_request=update_persona_request,
|
||||
create_persona_request=persona_upsert_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
persona_snapshot.prompts = [prompt_snapshot]
|
||||
return persona_snapshot
|
||||
|
||||
|
||||
class PersonaLabelPatchRequest(BaseModel):
|
||||
@@ -365,22 +406,6 @@ def get_persona(
|
||||
)
|
||||
|
||||
|
||||
@basic_router.get("/utils/prompt-explorer")
|
||||
def build_final_template_prompt(
|
||||
system_prompt: str,
|
||||
task_prompt: str,
|
||||
retrieval_disabled: bool = False,
|
||||
_: User | None = Depends(current_user),
|
||||
) -> PromptTemplateResponse:
|
||||
return PromptTemplateResponse(
|
||||
final_prompt_template=build_dummy_prompt(
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
retrieval_disabled=retrieval_disabled,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@basic_router.post("/assistant-prompt-refresh")
|
||||
def build_assistant_prompts(
|
||||
generate_persona_prompt_request: GenerateStarterMessageRequest,
|
||||
|
||||
@@ -7,9 +7,9 @@ from pydantic import Field
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import PersonaLabel
|
||||
from onyx.db.models import Prompt
|
||||
from onyx.db.models import StarterMessage
|
||||
from onyx.server.features.document_set.models import DocumentSet
|
||||
from onyx.server.features.prompt.models import PromptSnapshot
|
||||
from onyx.server.features.tool.models import ToolSnapshot
|
||||
from onyx.server.models import MinimalUserSnapshot
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -18,6 +18,34 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class PromptSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
include_citations: bool
|
||||
datetime_aware: bool
|
||||
default_prompt: bool
|
||||
# Not including persona info, not needed
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, prompt: Prompt) -> "PromptSnapshot":
|
||||
if prompt.deleted:
|
||||
raise ValueError("Prompt has been deleted")
|
||||
|
||||
return PromptSnapshot(
|
||||
id=prompt.id,
|
||||
name=prompt.name,
|
||||
description=prompt.description,
|
||||
system_prompt=prompt.system_prompt,
|
||||
task_prompt=prompt.task_prompt,
|
||||
include_citations=prompt.include_citations,
|
||||
datetime_aware=prompt.datetime_aware,
|
||||
default_prompt=prompt.default_prompt,
|
||||
)
|
||||
|
||||
|
||||
# More minimal request for generating a persona prompt
|
||||
class GenerateStarterMessageRequest(BaseModel):
|
||||
name: str
|
||||
@@ -27,32 +55,35 @@ class GenerateStarterMessageRequest(BaseModel):
|
||||
generation_count: int
|
||||
|
||||
|
||||
class CreatePersonaRequest(BaseModel):
|
||||
class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
llm_relevance_filter: bool
|
||||
include_citations: bool
|
||||
is_public: bool
|
||||
llm_filter_extraction: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
prompt_ids: list[int]
|
||||
document_set_ids: list[int]
|
||||
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
|
||||
tool_ids: list[int]
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
# For Private Personas, who should be able to access these
|
||||
users: list[UUID] = Field(default_factory=list)
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
|
||||
tool_ids: list[int]
|
||||
icon_color: str | None = None
|
||||
icon_shape: int | None = None
|
||||
uploaded_image_id: str | None = None # New field for uploaded image
|
||||
remove_image: bool | None = None
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
uploaded_image_id: str | None = None # New field for uploaded image
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
|
||||
|
||||
class PersonaSnapshot(BaseModel):
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette import status
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.persona import get_prompt_by_id
|
||||
from onyx.db.persona import get_prompts
|
||||
from onyx.db.persona import mark_prompt_as_deleted
|
||||
from onyx.db.persona import upsert_prompt
|
||||
from onyx.server.features.prompt.models import CreatePromptRequest
|
||||
from onyx.server.features.prompt.models import PromptSnapshot
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
# Note: As prompts are fairly innocuous/harmless, there are no protections
|
||||
# to prevent users from messing with prompts of other users.
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
basic_router = APIRouter(prefix="/prompt")
|
||||
|
||||
|
||||
def create_update_prompt(
|
||||
prompt_id: int | None,
|
||||
create_prompt_request: CreatePromptRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> PromptSnapshot:
|
||||
personas = (
|
||||
list(
|
||||
get_personas_by_ids(
|
||||
persona_ids=create_prompt_request.persona_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
if create_prompt_request.persona_ids
|
||||
else []
|
||||
)
|
||||
|
||||
prompt = upsert_prompt(
|
||||
prompt_id=prompt_id,
|
||||
user=user,
|
||||
name=create_prompt_request.name,
|
||||
description=create_prompt_request.description,
|
||||
system_prompt=create_prompt_request.system_prompt,
|
||||
task_prompt=create_prompt_request.task_prompt,
|
||||
include_citations=create_prompt_request.include_citations,
|
||||
datetime_aware=create_prompt_request.datetime_aware,
|
||||
personas=personas,
|
||||
db_session=db_session,
|
||||
)
|
||||
return PromptSnapshot.from_model(prompt)
|
||||
|
||||
|
||||
@basic_router.post("")
|
||||
def create_prompt(
|
||||
create_prompt_request: CreatePromptRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PromptSnapshot:
|
||||
try:
|
||||
return create_update_prompt(
|
||||
prompt_id=None,
|
||||
create_prompt_request=create_prompt_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError as ve:
|
||||
logger.exception(ve)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to create Persona, invalid info.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
@basic_router.patch("/{prompt_id}")
|
||||
def update_prompt(
|
||||
prompt_id: int,
|
||||
update_prompt_request: CreatePromptRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PromptSnapshot:
|
||||
try:
|
||||
return create_update_prompt(
|
||||
prompt_id=prompt_id,
|
||||
create_prompt_request=update_prompt_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError as ve:
|
||||
logger.exception(ve)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to create Persona, invalid info.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
|
||||
|
||||
@basic_router.delete("/{prompt_id}")
|
||||
def delete_prompt(
|
||||
prompt_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
mark_prompt_as_deleted(
|
||||
prompt_id=prompt_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@basic_router.get("")
|
||||
def list_prompts(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PromptSnapshot]:
|
||||
user_id = user.id if user is not None else None
|
||||
return [
|
||||
PromptSnapshot.from_model(prompt)
|
||||
for prompt in get_prompts(user_id=user_id, db_session=db_session)
|
||||
]
|
||||
|
||||
|
||||
@basic_router.get("/{prompt_id}")
|
||||
def get_prompt(
|
||||
prompt_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PromptSnapshot:
|
||||
return PromptSnapshot.from_model(
|
||||
get_prompt_by_id(
|
||||
prompt_id=prompt_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
@@ -1,41 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.models import Prompt
|
||||
|
||||
|
||||
class CreatePromptRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
include_citations: bool = False
|
||||
datetime_aware: bool = False
|
||||
persona_ids: list[int] | None = None
|
||||
|
||||
|
||||
class PromptSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
include_citations: bool
|
||||
datetime_aware: bool
|
||||
default_prompt: bool
|
||||
# Not including persona info, not needed
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, prompt: Prompt) -> "PromptSnapshot":
|
||||
if prompt.deleted:
|
||||
raise ValueError("Prompt has been deleted")
|
||||
|
||||
return PromptSnapshot(
|
||||
id=prompt.id,
|
||||
name=prompt.name,
|
||||
description=prompt.description,
|
||||
system_prompt=prompt.system_prompt,
|
||||
task_prompt=prompt.task_prompt,
|
||||
include_citations=prompt.include_citations,
|
||||
datetime_aware=prompt.datetime_aware,
|
||||
default_prompt=prompt.default_prompt,
|
||||
)
|
||||
@@ -27,6 +27,8 @@ from onyx.server.manage.models import SlackBot
|
||||
from onyx.server.manage.models import SlackBotCreationRequest
|
||||
from onyx.server.manage.models import SlackChannelConfig
|
||||
from onyx.server.manage.models import SlackChannelConfigCreationRequest
|
||||
from onyx.server.manage.validate_tokens import validate_app_token
|
||||
from onyx.server.manage.validate_tokens import validate_bot_token
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
|
||||
|
||||
@@ -222,6 +224,9 @@ def create_bot(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> SlackBot:
|
||||
validate_app_token(slack_bot_creation_request.app_token)
|
||||
validate_bot_token(slack_bot_creation_request.bot_token)
|
||||
|
||||
slack_bot_model = insert_slack_bot(
|
||||
db_session=db_session,
|
||||
name=slack_bot_creation_request.name,
|
||||
@@ -248,6 +253,8 @@ def patch_bot(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> SlackBot:
|
||||
validate_bot_token(slack_bot_creation_request.bot_token)
|
||||
validate_app_token(slack_bot_creation_request.app_token)
|
||||
slack_bot_model = update_slack_bot(
|
||||
db_session=db_session,
|
||||
slack_bot_id=slack_bot_id,
|
||||
|
||||
43
backend/onyx/server/manage/validate_tokens.py
Normal file
43
backend/onyx/server/manage/validate_tokens.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
SLACK_API_URL = "https://slack.com/api/auth.test"
|
||||
SLACK_CONNECTIONS_OPEN_URL = "https://slack.com/api/apps.connections.open"
|
||||
|
||||
|
||||
def validate_bot_token(bot_token: str) -> bool:
|
||||
headers = {"Authorization": f"Bearer {bot_token}"}
|
||||
response = requests.post(SLACK_API_URL, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error communicating with Slack API."
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if not data.get("ok", False):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid bot token: {data.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_app_token(app_token: str) -> bool:
|
||||
headers = {"Authorization": f"Bearer {app_token}"}
|
||||
response = requests.post(SLACK_CONNECTIONS_OPEN_URL, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Error communicating with Slack API."
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if not data.get("ok", False):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid app token: {data.get('error', 'Unknown error')}",
|
||||
)
|
||||
|
||||
return True
|
||||
@@ -108,7 +108,7 @@ def upsert_ingestion_doc(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
new_doc, __chunk_count = indexing_pipeline(
|
||||
indexing_pipeline_result = indexing_pipeline(
|
||||
document_batch=[document],
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -150,4 +150,7 @@ def upsert_ingestion_doc(
|
||||
),
|
||||
)
|
||||
|
||||
return IngestionResult(document_id=document.id, already_existed=not bool(new_doc))
|
||||
return IngestionResult(
|
||||
document_id=document.id,
|
||||
already_existed=indexing_pipeline_result.new_docs > 0,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.persona import get_personas_for_user
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.db.persona import upsert_prompt
|
||||
from onyx.db.prompts import upsert_prompt
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -479,7 +479,10 @@ def get_max_document_tokens(
|
||||
raise HTTPException(status_code=404, detail="Persona not found")
|
||||
|
||||
return MaxSelectedDocumentTokens(
|
||||
max_tokens=compute_max_document_tokens_for_persona(persona),
|
||||
max_tokens=compute_max_document_tokens_for_persona(
|
||||
db_session=db_session,
|
||||
persona=persona,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class LongTermLogger:
|
||||
def _cleanup_old_files(self, category_path: Path) -> None:
|
||||
try:
|
||||
files = sorted(
|
||||
category_path.glob("*.json"),
|
||||
[f for f in category_path.glob("*.json") if f.is_file()],
|
||||
key=lambda x: x.stat().st_mtime, # Sort by modification time
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
@@ -108,6 +108,7 @@ logger = getLogger(__name__)
|
||||
|
||||
|
||||
# class MessageSnapshot(BaseModel):
|
||||
# id: int
|
||||
# message: str
|
||||
# message_type: MessageType
|
||||
# documents: list[AbridgedSearchDoc]
|
||||
|
||||
@@ -133,3 +133,25 @@ class ChatSessionManager:
|
||||
)
|
||||
for msg in response.json()["messages"]
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def create_chat_message_feedback(
|
||||
message_id: int,
|
||||
is_positive: bool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
feedback_text: str | None = None,
|
||||
predefined_feedback: str | None = None,
|
||||
) -> None:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/chat/create-chat-message-feedback",
|
||||
json={
|
||||
"chat_message_id": message_id,
|
||||
"is_positive": is_positive,
|
||||
"feedback_text": feedback_text,
|
||||
"predefined_feedback": predefined_feedback,
|
||||
},
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestPersona
|
||||
@@ -16,6 +18,9 @@ class PersonaManager:
|
||||
def create(
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
include_citations: bool = False,
|
||||
num_chunks: float = 5,
|
||||
llm_relevance_filter: bool = True,
|
||||
is_public: bool = True,
|
||||
@@ -28,32 +33,38 @@ class PersonaManager:
|
||||
llm_model_version_override: str | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
category_id: int | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
system_prompt = system_prompt or f"System prompt for {name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {name}"
|
||||
|
||||
persona_creation_request = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"num_chunks": num_chunks,
|
||||
"llm_relevance_filter": llm_relevance_filter,
|
||||
"is_public": is_public,
|
||||
"llm_filter_extraction": llm_filter_extraction,
|
||||
"recency_bias": recency_bias,
|
||||
"prompt_ids": prompt_ids or [0],
|
||||
"document_set_ids": document_set_ids or [],
|
||||
"tool_ids": tool_ids or [],
|
||||
"llm_model_provider_override": llm_model_provider_override,
|
||||
"llm_model_version_override": llm_model_version_override,
|
||||
"users": users or [],
|
||||
"groups": groups or [],
|
||||
}
|
||||
persona_creation_request = PersonaUpsertRequest(
|
||||
name=name,
|
||||
description=description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
prompt_ids=prompt_ids or [0],
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
users=[UUID(user) for user in (users or [])],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
json=persona_creation_request,
|
||||
json=persona_creation_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
@@ -77,6 +88,7 @@ class PersonaManager:
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -84,6 +96,9 @@ class PersonaManager:
|
||||
persona: DATestPersona,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
include_citations: bool = False,
|
||||
num_chunks: float | None = None,
|
||||
llm_relevance_filter: bool | None = None,
|
||||
is_public: bool | None = None,
|
||||
@@ -96,32 +111,38 @@ class PersonaManager:
|
||||
llm_model_version_override: str | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestPersona:
|
||||
persona_update_request = {
|
||||
"name": name or persona.name,
|
||||
"description": description or persona.description,
|
||||
"num_chunks": num_chunks or persona.num_chunks,
|
||||
"llm_relevance_filter": llm_relevance_filter
|
||||
or persona.llm_relevance_filter,
|
||||
"is_public": is_public or persona.is_public,
|
||||
"llm_filter_extraction": llm_filter_extraction
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
persona_update_request = PersonaUpsertRequest(
|
||||
name=name or persona.name,
|
||||
description=description or persona.description,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
include_citations=include_citations,
|
||||
num_chunks=num_chunks or persona.num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
|
||||
is_public=is_public or persona.is_public,
|
||||
llm_filter_extraction=llm_filter_extraction
|
||||
or persona.llm_filter_extraction,
|
||||
"recency_bias": recency_bias or persona.recency_bias,
|
||||
"prompt_ids": prompt_ids or persona.prompt_ids,
|
||||
"document_set_ids": document_set_ids or persona.document_set_ids,
|
||||
"tool_ids": tool_ids or persona.tool_ids,
|
||||
"llm_model_provider_override": llm_model_provider_override
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
prompt_ids=prompt_ids or persona.prompt_ids,
|
||||
document_set_ids=document_set_ids or persona.document_set_ids,
|
||||
tool_ids=tool_ids or persona.tool_ids,
|
||||
llm_model_provider_override=llm_model_provider_override
|
||||
or persona.llm_model_provider_override,
|
||||
"llm_model_version_override": llm_model_version_override
|
||||
llm_model_version_override=llm_model_version_override
|
||||
or persona.llm_model_version_override,
|
||||
"users": users or persona.users,
|
||||
"groups": groups or persona.groups,
|
||||
}
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
label_ids=label_ids or persona.label_ids,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=persona_update_request,
|
||||
json=persona_update_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
@@ -137,8 +158,8 @@ class PersonaManager:
|
||||
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
|
||||
is_public=updated_persona_data["is_public"],
|
||||
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
|
||||
recency_bias=updated_persona_data["recency_bias"],
|
||||
prompt_ids=updated_persona_data["prompts"],
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
prompt_ids=[prompt["id"] for prompt in updated_persona_data["prompts"]],
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
@@ -149,6 +170,7 @@ class PersonaManager:
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
label_ids=updated_persona_data["labels"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -164,12 +186,29 @@ class PersonaManager:
|
||||
response.raise_for_status()
|
||||
return [PersonaSnapshot(**persona) for persona in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def get_one(
|
||||
persona_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[PersonaSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona_id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [PersonaSnapshot(**response.json())]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
persona: DATestPersona,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
all_personas = PersonaManager.get_all(user_performing_action)
|
||||
all_personas = PersonaManager.get_one(
|
||||
persona_id=persona.id,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
for fetched_persona in all_personas:
|
||||
if fetched_persona.id == persona.id:
|
||||
return (
|
||||
@@ -199,6 +238,7 @@ class PersonaManager:
|
||||
and set(user.email for user in fetched_persona.users)
|
||||
== set(persona.users)
|
||||
and set(fetched_persona.groups) == set(persona.groups)
|
||||
and set(fetched_persona.labels) == set(persona.label_ids)
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
from requests.models import CaseInsensitiveDict
|
||||
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class QueryHistoryManager:
|
||||
@staticmethod
|
||||
def get_query_history_page(
|
||||
page_num: int = 0,
|
||||
page_size: int = 10,
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
query_params: dict[str, str | int] = {
|
||||
"page_num": page_num,
|
||||
"page_size": page_size,
|
||||
}
|
||||
if feedback_type:
|
||||
query_params["feedback_type"] = feedback_type.value
|
||||
if start_time:
|
||||
query_params["start_time"] = start_time.isoformat()
|
||||
if end_time:
|
||||
query_params["end_time"] = end_time.isoformat()
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/chat-session-history?{urlencode(query_params, doseq=True)}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return PaginatedReturn(
|
||||
items=[ChatSessionMinimal(**item) for item in data["items"]],
|
||||
total_items=data["total_items"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_chat_session_admin(
|
||||
chat_session_id: UUID | str,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> ChatSessionSnapshot:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/chat-session-history/{chat_session_id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return ChatSessionSnapshot(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def get_query_history_as_csv(
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> tuple[CaseInsensitiveDict[str], str]:
|
||||
query_params: dict[str, str | int] = {}
|
||||
if start_time:
|
||||
query_params["start"] = start_time.isoformat()
|
||||
if end_time:
|
||||
query_params["end"] = end_time.isoformat()
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.headers, response.content.decode()
|
||||
@@ -213,17 +213,16 @@ class UserManager:
|
||||
is_active_filter: bool | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> PaginatedReturn[FullUserSnapshot]:
|
||||
query_params = {
|
||||
query_params: dict[str, str | list[str] | int] = {
|
||||
"page_num": page_num,
|
||||
"page_size": page_size,
|
||||
"q": search_query if search_query else None,
|
||||
"roles": [role.value for role in role_filter] if role_filter else None,
|
||||
"is_active": is_active_filter if is_active_filter is not None else None,
|
||||
}
|
||||
# Remove None values
|
||||
query_params = {
|
||||
key: value for key, value in query_params.items() if value is not None
|
||||
}
|
||||
if search_query:
|
||||
query_params["q"] = search_query
|
||||
if role_filter:
|
||||
query_params["roles"] = [role.value for role in role_filter]
|
||||
if is_active_filter is not None:
|
||||
query_params["is_active"] = is_active_filter
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/users/accepted?{urlencode(query_params, doseq=True)}",
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
@@ -127,14 +128,7 @@ class DATestPersona(BaseModel):
|
||||
llm_model_version_override: str | None
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
category_id: int | None = None
|
||||
|
||||
|
||||
#
|
||||
class DATestChatSession(BaseModel):
|
||||
id: UUID
|
||||
persona_id: int
|
||||
description: str
|
||||
label_ids: list[int]
|
||||
|
||||
|
||||
class DATestChatMessage(BaseModel):
|
||||
@@ -144,6 +138,16 @@ class DATestChatMessage(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class DATestChatSession(BaseModel):
|
||||
id: UUID
|
||||
persona_id: int
|
||||
description: str
|
||||
|
||||
|
||||
class DAQueryHistoryEntry(DATestChatSession):
|
||||
feedback_type: QAFeedbackType | None
|
||||
|
||||
|
||||
class StreamedResponse(BaseModel):
|
||||
full_message: str = ""
|
||||
rephrased_query: str | None = None
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
This file tests the permissions for creating and editing personas for different user roles:
|
||||
- Basic users can create personas and edit their own
|
||||
- Curators can edit personas that belong exclusively to groups they curate
|
||||
- Admins can edit all personas
|
||||
"""
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.managers.user import DATestUser
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
def test_persona_permissions(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Creating a curator user
|
||||
curator: DATestUser = UserManager.create(name="curator")
|
||||
|
||||
# Creating a basic user
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user")
|
||||
|
||||
# Creating user groups
|
||||
user_group_1 = UserGroupManager.create(
|
||||
name="curated_user_group",
|
||||
user_ids=[curator.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[user_group_1], user_performing_action=admin_user
|
||||
)
|
||||
# Setting the user as a curator for the user group
|
||||
UserGroupManager.set_curator_status(
|
||||
test_user_group=user_group_1,
|
||||
user_to_set_as_curator=curator,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Creating another user group that the user is not a curator of
|
||||
user_group_2 = UserGroupManager.create(
|
||||
name="uncurated_user_group",
|
||||
user_ids=[curator.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[user_group_2], user_performing_action=admin_user
|
||||
)
|
||||
|
||||
"""Test that any user can create a persona"""
|
||||
# Basic user creates a persona
|
||||
basic_user_persona = PersonaManager.create(
|
||||
name="basic_user_persona",
|
||||
description="A persona created by basic user",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
|
||||
|
||||
# Curator creates a persona
|
||||
curator_persona = PersonaManager.create(
|
||||
name="curator_persona",
|
||||
description="A persona created by curator",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
user_performing_action=curator,
|
||||
)
|
||||
PersonaManager.verify(curator_persona, user_performing_action=curator)
|
||||
|
||||
# Admin creates personas for different groups
|
||||
admin_persona_group_1 = PersonaManager.create(
|
||||
name="admin_persona_group_1",
|
||||
description="A persona for group 1",
|
||||
is_public=False,
|
||||
groups=[user_group_1.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
admin_persona_group_2 = PersonaManager.create(
|
||||
name="admin_persona_group_2",
|
||||
description="A persona for group 2",
|
||||
is_public=False,
|
||||
groups=[user_group_2.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
admin_persona_both_groups = PersonaManager.create(
|
||||
name="admin_persona_both_groups",
|
||||
description="A persona for both groups",
|
||||
is_public=False,
|
||||
groups=[user_group_1.id, user_group_2.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
"""Test that users can edit their own personas"""
|
||||
# Basic user can edit their own persona
|
||||
PersonaManager.edit(
|
||||
persona=basic_user_persona,
|
||||
description="Updated description by basic user",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
|
||||
|
||||
# Basic user cannot edit other's personas
|
||||
with pytest.raises(HTTPError):
|
||||
PersonaManager.edit(
|
||||
persona=curator_persona,
|
||||
description="Invalid edit by basic user",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
"""Test curator permissions"""
|
||||
# Curator can edit personas that belong exclusively to groups they curate
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_group_1,
|
||||
description="Updated by curator",
|
||||
user_performing_action=curator,
|
||||
)
|
||||
PersonaManager.verify(admin_persona_group_1, user_performing_action=curator)
|
||||
|
||||
# Curator cannot edit personas in groups they don't curate
|
||||
with pytest.raises(HTTPError):
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_group_2,
|
||||
description="Invalid edit by curator",
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curator cannot edit personas that belong to multiple groups, even if they curate one
|
||||
with pytest.raises(HTTPError):
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_both_groups,
|
||||
description="Invalid edit by curator",
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
"""Test admin permissions"""
|
||||
# Admin can edit any persona
|
||||
PersonaManager.edit(
|
||||
persona=basic_user_persona,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(basic_user_persona, user_performing_action=admin_user)
|
||||
|
||||
PersonaManager.edit(
|
||||
persona=curator_persona,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(curator_persona, user_performing_action=admin_user)
|
||||
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_group_1,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(admin_persona_group_1, user_performing_action=admin_user)
|
||||
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_group_2,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(admin_persona_group_2, user_performing_action=admin_user)
|
||||
|
||||
PersonaManager.edit(
|
||||
persona=admin_persona_both_groups,
|
||||
description="Updated by admin",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(admin_persona_both_groups, user_performing_action=admin_user)
|
||||
@@ -3,16 +3,15 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import SessionType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.query_history import QueryHistoryManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -69,66 +68,52 @@ def test_chat_history_endpoints(
|
||||
) -> None:
|
||||
admin_user, first_chat_id = setup_chat_session
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/chat-session-history",
|
||||
headers=admin_user.headers,
|
||||
# Get chat history
|
||||
history_response = QueryHistoryManager.get_query_history_page(
|
||||
user_performing_action=admin_user
|
||||
)
|
||||
assert response.status_code == 200
|
||||
history_response = response.json()
|
||||
|
||||
# Verify we got back the one chat session we created
|
||||
assert len(history_response) == 1
|
||||
assert len(history_response.items) == 1
|
||||
|
||||
# Verify the first chat session details
|
||||
first_session = history_response[0]
|
||||
assert first_session["user_email"] == admin_user.email
|
||||
assert first_session["name"] == "Test chat session"
|
||||
assert first_session["first_user_message"] == "What was the Q1 revenue?"
|
||||
assert first_session["first_ai_message"] is not None
|
||||
assert first_session["assistant_id"] == 0
|
||||
assert first_session["feedback_type"] is None
|
||||
assert first_session["flow_type"] == SessionType.CHAT.value
|
||||
assert first_session["conversation_length"] == 4 # 2 User messages + 2 AI responses
|
||||
first_session = history_response.items[0]
|
||||
assert first_session.user_email == admin_user.email
|
||||
assert first_session.name == "Test chat session"
|
||||
assert first_session.first_user_message == "What was the Q1 revenue?"
|
||||
assert first_session.first_ai_message is not None
|
||||
assert first_session.assistant_id == 0
|
||||
assert first_session.feedback_type is None
|
||||
assert first_session.flow_type == SessionType.CHAT
|
||||
assert first_session.conversation_length == 4 # 2 User messages + 2 AI responses
|
||||
|
||||
# Test date filtering - should return no results
|
||||
past_end = datetime.now(tz=timezone.utc) - timedelta(days=1)
|
||||
past_start = past_end - timedelta(days=1)
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/chat-session-history",
|
||||
params={
|
||||
"start": past_start.isoformat(),
|
||||
"end": past_end.isoformat(),
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
history_response = QueryHistoryManager.get_query_history_page(
|
||||
start_time=past_start,
|
||||
end_time=past_end,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
history_response = response.json()
|
||||
assert len(history_response) == 0
|
||||
assert len(history_response.items) == 0
|
||||
|
||||
# Test get specific chat session endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/chat-session-history/{first_chat_id}",
|
||||
headers=admin_user.headers,
|
||||
session_details = QueryHistoryManager.get_chat_session_admin(
|
||||
chat_session_id=first_chat_id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
session_details = response.json()
|
||||
|
||||
# Verify the session details
|
||||
assert session_details["id"] == first_chat_id
|
||||
assert len(session_details["messages"]) > 0
|
||||
assert session_details["flow_type"] == SessionType.CHAT.value
|
||||
assert str(session_details.id) == first_chat_id
|
||||
assert len(session_details.messages) > 0
|
||||
assert session_details.flow_type == SessionType.CHAT
|
||||
|
||||
# Test filtering by feedback
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/chat-session-history",
|
||||
params={
|
||||
"feedback_type": QAFeedbackType.LIKE.value,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
history_response = QueryHistoryManager.get_query_history_page(
|
||||
feedback_type=QAFeedbackType.LIKE,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
history_response = response.json()
|
||||
assert len(history_response) == 0
|
||||
assert len(history_response.items) == 0
|
||||
|
||||
|
||||
def test_chat_history_csv_export(
|
||||
@@ -137,16 +122,13 @@ def test_chat_history_csv_export(
|
||||
admin_user, _ = setup_chat_session
|
||||
|
||||
# Test CSV export endpoint with date filtering
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/query-history-csv",
|
||||
headers=admin_user.headers,
|
||||
headers, csv_content = QueryHistoryManager.get_query_history_as_csv(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["Content-Type"] == "text/csv; charset=utf-8"
|
||||
assert "Content-Disposition" in response.headers
|
||||
assert headers["Content-Type"] == "text/csv; charset=utf-8"
|
||||
assert "Content-Disposition" in headers
|
||||
|
||||
# Verify CSV content
|
||||
csv_content = response.content.decode()
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
assert len(csv_lines) == 3 # Header + 2 QA pairs
|
||||
assert "chat_session_id" in csv_content
|
||||
@@ -158,15 +140,10 @@ def test_chat_history_csv_export(
|
||||
# Test CSV export with date filtering - should return no results
|
||||
past_end = datetime.now(tz=timezone.utc) - timedelta(days=1)
|
||||
past_start = past_end - timedelta(days=1)
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/query-history-csv",
|
||||
params={
|
||||
"start": past_start.isoformat(),
|
||||
"end": past_end.isoformat(),
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
headers, csv_content = QueryHistoryManager.get_query_history_as_csv(
|
||||
start_time=past_start,
|
||||
end_time=past_end,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
csv_content = response.content.decode()
|
||||
csv_lines = csv_content.strip().split("\n")
|
||||
assert len(csv_lines) == 1 # Only header, no data rows
|
||||
@@ -0,0 +1,112 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from tests.integration.common_utils.managers.query_history import QueryHistoryManager
|
||||
from tests.integration.common_utils.test_models import DAQueryHistoryEntry
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.tests.query_history.utils import (
|
||||
setup_chat_sessions_with_different_feedback,
|
||||
)
|
||||
|
||||
|
||||
def _verify_query_history_pagination(
|
||||
chat_sessions: list[DAQueryHistoryEntry],
|
||||
page_size: int = 5,
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_sessions: list[str] = []
|
||||
|
||||
for i in range(0, len(chat_sessions), page_size):
|
||||
paginated_result = QueryHistoryManager.get_query_history_page(
|
||||
page_num=i // page_size,
|
||||
page_size=page_size,
|
||||
feedback_type=feedback_type,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
|
||||
# Verify that the total items is equal to the length of the chat sessions list
|
||||
assert paginated_result.total_items == len(chat_sessions)
|
||||
# Verify that the number of items in the page is equal to the page size
|
||||
assert len(paginated_result.items) == min(page_size, len(chat_sessions) - i)
|
||||
# Add the retrieved chat sessions to the list of retrieved sessions
|
||||
retrieved_sessions.extend(
|
||||
[str(session.id) for session in paginated_result.items]
|
||||
)
|
||||
|
||||
# Create a set of all the expected chat session IDs
|
||||
all_expected_sessions = set(str(session.id) for session in chat_sessions)
|
||||
# Create a set of all the retrieved chat session IDs
|
||||
all_retrieved_sessions = set(retrieved_sessions)
|
||||
|
||||
# Verify that the set of retrieved sessions is equal to the set of expected sessions
|
||||
assert all_expected_sessions == all_retrieved_sessions
|
||||
|
||||
|
||||
def test_query_history_pagination(reset: None) -> None:
|
||||
(
|
||||
admin_user,
|
||||
chat_sessions_by_feedback_type,
|
||||
) = setup_chat_sessions_with_different_feedback()
|
||||
|
||||
all_chat_sessions = []
|
||||
for _, chat_sessions in chat_sessions_by_feedback_type.items():
|
||||
all_chat_sessions.extend(chat_sessions)
|
||||
|
||||
# Verify basic pagination with different page sizes
|
||||
print("Verifying basic pagination with page size 5")
|
||||
_verify_query_history_pagination(
|
||||
chat_sessions=all_chat_sessions,
|
||||
page_size=5,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
print("Verifying basic pagination with page size 10")
|
||||
_verify_query_history_pagination(
|
||||
chat_sessions=all_chat_sessions,
|
||||
page_size=10,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
print("Verifying pagination with feedback type LIKE")
|
||||
liked_sessions = chat_sessions_by_feedback_type[QAFeedbackType.LIKE]
|
||||
_verify_query_history_pagination(
|
||||
chat_sessions=liked_sessions,
|
||||
feedback_type=QAFeedbackType.LIKE,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
print("Verifying pagination with feedback type DISLIKE")
|
||||
disliked_sessions = chat_sessions_by_feedback_type[QAFeedbackType.DISLIKE]
|
||||
_verify_query_history_pagination(
|
||||
chat_sessions=disliked_sessions,
|
||||
feedback_type=QAFeedbackType.DISLIKE,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
print("Verifying pagination with feedback type MIXED")
|
||||
mixed_sessions = chat_sessions_by_feedback_type[QAFeedbackType.MIXED]
|
||||
_verify_query_history_pagination(
|
||||
chat_sessions=mixed_sessions,
|
||||
feedback_type=QAFeedbackType.MIXED,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Test with a small page size to verify handling of partial pages
|
||||
print("Verifying pagination with page size 3")
|
||||
_verify_query_history_pagination(
|
||||
chat_sessions=all_chat_sessions,
|
||||
page_size=3,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Test with a page size larger than the total number of items
|
||||
print("Verifying pagination with page size 50")
|
||||
_verify_query_history_pagination(
|
||||
chat_sessions=all_chat_sessions,
|
||||
page_size=50,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
131
backend/tests/integration/tests/query_history/utils.py
Normal file
131
backend/tests/integration/tests/query_history/utils.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DAQueryHistoryEntry
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _create_chat_session_with_feedback(
|
||||
admin_user: DATestUser,
|
||||
i: int,
|
||||
feedback_type: QAFeedbackType | None,
|
||||
) -> tuple[QAFeedbackType | None, DAQueryHistoryEntry]:
|
||||
print(f"Creating chat session {i} with feedback type {feedback_type}")
|
||||
# Create chat session with timestamp spread over 30 days
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
description=f"Test chat session {i}",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
test_session = DAQueryHistoryEntry(
|
||||
id=chat_session.id,
|
||||
persona_id=0,
|
||||
description=f"Test chat session {i}",
|
||||
feedback_type=feedback_type,
|
||||
)
|
||||
|
||||
# First message in chat
|
||||
ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message=f"Question {i}?",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
messages = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
if feedback_type == QAFeedbackType.MIXED or feedback_type == QAFeedbackType.DISLIKE:
|
||||
ChatSessionManager.create_chat_message_feedback(
|
||||
message_id=messages[-1].id,
|
||||
is_positive=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Second message with different feedback types
|
||||
ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message=f"Follow up {i}?",
|
||||
user_performing_action=admin_user,
|
||||
parent_message_id=messages[-1].id,
|
||||
)
|
||||
|
||||
# Get updated messages to get the ID of the second message
|
||||
messages = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
if feedback_type == QAFeedbackType.MIXED or feedback_type == QAFeedbackType.LIKE:
|
||||
ChatSessionManager.create_chat_message_feedback(
|
||||
message_id=messages[-1].id,
|
||||
is_positive=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
return feedback_type, test_session
|
||||
|
||||
|
||||
def setup_chat_sessions_with_different_feedback() -> (
|
||||
tuple[DATestUser, dict[QAFeedbackType | None, list[DAQueryHistoryEntry]]]
|
||||
):
|
||||
# Create admin user and required resources
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin_user)
|
||||
api_key = APIKeyManager.create(user_performing_action=admin_user)
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
# Seed a document
|
||||
cc_pair.documents = []
|
||||
cc_pair.documents.append(
|
||||
DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair,
|
||||
content="The company's revenue in Q1 was $1M",
|
||||
api_key=api_key,
|
||||
)
|
||||
)
|
||||
|
||||
chat_sessions_by_feedback_type: dict[
|
||||
QAFeedbackType | None, list[DAQueryHistoryEntry]
|
||||
] = {}
|
||||
# Use ThreadPoolExecutor to create chat sessions in parallel
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
# Submit all tasks and store futures
|
||||
j = 0
|
||||
# Will result in 40 sessions
|
||||
number_of_sessions = 10
|
||||
futures = []
|
||||
for feedback_type in [
|
||||
QAFeedbackType.MIXED,
|
||||
QAFeedbackType.LIKE,
|
||||
QAFeedbackType.DISLIKE,
|
||||
None,
|
||||
]:
|
||||
futures.extend(
|
||||
[
|
||||
executor.submit(
|
||||
_create_chat_session_with_feedback,
|
||||
admin_user,
|
||||
(j * number_of_sessions) + i,
|
||||
feedback_type,
|
||||
)
|
||||
for i in range(number_of_sessions)
|
||||
]
|
||||
)
|
||||
j += 1
|
||||
|
||||
# Collect results in order
|
||||
for future in as_completed(futures):
|
||||
feedback_type, chat_session = future.result()
|
||||
chat_sessions_by_feedback_type.setdefault(feedback_type, []).append(
|
||||
chat_session
|
||||
)
|
||||
|
||||
return admin_user, chat_sessions_by_feedback_type
|
||||
@@ -43,12 +43,6 @@ def _verify_user_pagination(
|
||||
assert all_expected_emails == all_retrieved_emails
|
||||
|
||||
|
||||
def _verify_user_role_and_status(users: list) -> None:
|
||||
for user in users:
|
||||
assert UserManager.is_role(user, user.role)
|
||||
assert UserManager.is_status(user, user.is_active)
|
||||
|
||||
|
||||
def test_user_pagination(reset: None) -> None:
|
||||
# Create an admin user to perform actions
|
||||
user_performing_action: DATestUser = UserManager.create(
|
||||
@@ -108,7 +102,13 @@ def test_user_pagination(reset: None) -> None:
|
||||
+ inactive_admins
|
||||
+ searchable_curators
|
||||
)
|
||||
_verify_user_role_and_status(all_users)
|
||||
for user in all_users:
|
||||
# Verify that the user's role in the db matches
|
||||
# the role in the user object
|
||||
assert UserManager.is_role(user, user.role)
|
||||
# Verify that the user's status in the db matches
|
||||
# the status in the user object
|
||||
assert UserManager.is_status(user, user.is_active)
|
||||
|
||||
# Verify pagination
|
||||
_verify_user_pagination(
|
||||
|
||||
11
web/jest.config.js
Normal file
11
web/jest.config.js
Normal file
@@ -0,0 +1,11 @@
|
||||
module.exports = {
|
||||
preset: "ts-jest",
|
||||
testEnvironment: "node",
|
||||
moduleNameMapper: {
|
||||
"^@/(.*)$": "<rootDir>/src/$1",
|
||||
},
|
||||
testPathIgnorePatterns: ["/node_modules/", "/tests/e2e/"],
|
||||
transform: {
|
||||
"^.+\\.tsx?$": "ts-jest",
|
||||
},
|
||||
};
|
||||
2773
web/package-lock.json
generated
2773
web/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,8 @@
|
||||
"dev": "next dev --turbopack",
|
||||
"build": "next build",
|
||||
"start": "next start",
|
||||
"lint": "next lint"
|
||||
"lint": "next lint",
|
||||
"test": "jest"
|
||||
},
|
||||
"dependencies": {
|
||||
"@dnd-kit/core": "^6.1.0",
|
||||
@@ -84,10 +85,13 @@
|
||||
"@chromatic-com/playwright": "^0.10.0",
|
||||
"@tailwindcss/typography": "^0.5.10",
|
||||
"@types/chrome": "^0.0.287",
|
||||
"@types/jest": "^29.5.14",
|
||||
"chromatic": "^11.18.1",
|
||||
"eslint": "^8.48.0",
|
||||
"eslint-config-next": "^14.1.0",
|
||||
"prettier": "2.8.8"
|
||||
"jest": "^29.7.0",
|
||||
"prettier": "2.8.8",
|
||||
"ts-jest": "^29.2.5"
|
||||
},
|
||||
"overrides": {
|
||||
"react-is": "^19.0.0-rc-69d4b800-20241021"
|
||||
|
||||
@@ -6,14 +6,11 @@ import { generateRandomIconShape } from "@/lib/assistantIconUtils";
|
||||
import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
|
||||
import { ArrayHelpers, FieldArray, Form, Formik, FormikProps } from "formik";
|
||||
|
||||
import {
|
||||
BooleanFormField,
|
||||
Label,
|
||||
SelectorFormField,
|
||||
TextFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
|
||||
@@ -42,14 +39,10 @@ import { FiInfo, FiRefreshCcw, FiUsers } from "react-icons/fi";
|
||||
import * as Yup from "yup";
|
||||
import CollapsibleSection from "./CollapsibleSection";
|
||||
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
||||
import {
|
||||
Persona,
|
||||
PersonaLabel,
|
||||
StarterMessage,
|
||||
StarterMessageBase,
|
||||
} from "./interfaces";
|
||||
import { Persona, PersonaLabel, StarterMessage } from "./interfaces";
|
||||
import {
|
||||
createPersonaLabel,
|
||||
PersonaUpsertParameters,
|
||||
createPersona,
|
||||
deletePersonaLabel,
|
||||
updatePersonaLabel,
|
||||
@@ -67,30 +60,19 @@ import { useAssistants } from "@/components/context/AssistantsContext";
|
||||
import { debounce } from "lodash";
|
||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||
import StarterMessagesList from "./StarterMessageList";
|
||||
import { LabelCard } from "./LabelCard";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
import { generateIdenticon } from "@/components/assistants/AssistantIcon";
|
||||
import { BackButton } from "@/components/BackButton";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import { AssistantVisibilityPopover } from "@/app/assistants/mine/AssistantVisibilityPopover";
|
||||
import { MinimalUserSnapshot } from "@/lib/types";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { useUsers } from "@/lib/hooks";
|
||||
import { AllUsersResponse } from "@/lib/types";
|
||||
// import { Badge } from "@/components/ui/Badge";
|
||||
// import {
|
||||
// addUsersToAssistantSharedList,
|
||||
// shareAssistantWithGroups,
|
||||
// } from "@/lib/assistants/shareAssistant";
|
||||
import {
|
||||
SearchMultiSelectDropdown,
|
||||
Option as DropdownOption,
|
||||
} from "@/components/Dropdown";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { SourceChip } from "@/app/chat/input/ChatInputBar";
|
||||
import { GroupIcon, TagIcon, UserIcon } from "lucide-react";
|
||||
import { TagIcon, UserIcon } from "lucide-react";
|
||||
import { LLMSelector } from "@/components/llm/LLMSelector";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
@@ -268,7 +250,6 @@ export function AssistantEditor({
|
||||
labels: existingPersona?.labels ?? null,
|
||||
|
||||
// EE Only
|
||||
groups: existingPersona?.groups ?? [],
|
||||
label_ids: existingPersona?.labels?.map((label) => label.id) ?? [],
|
||||
selectedUsers:
|
||||
existingPersona?.users?.filter(
|
||||
@@ -418,7 +399,6 @@ export function AssistantEditor({
|
||||
icon_shape: Yup.number(),
|
||||
uploaded_image: Yup.mixed().nullable(),
|
||||
// EE Only
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
label_ids: Yup.array().of(Yup.number()),
|
||||
selectedUsers: Yup.array().of(Yup.object()),
|
||||
selectedGroups: Yup.array().of(Yup.number()),
|
||||
@@ -494,12 +474,13 @@ export function AssistantEditor({
|
||||
}));
|
||||
|
||||
// don't set groups if marked as public
|
||||
const groups = values.is_public ? [] : values.groups;
|
||||
|
||||
const submissionData = {
|
||||
const groups = values.is_public ? [] : values.selectedGroups;
|
||||
const submissionData: PersonaUpsertParameters = {
|
||||
...values,
|
||||
existing_prompt_id: existingPrompt?.id ?? null,
|
||||
is_default_persona: admin!,
|
||||
starter_messages: starterMessages,
|
||||
groups: values.is_public ? [] : values.selectedGroups,
|
||||
groups: groups,
|
||||
users: values.is_public
|
||||
? undefined
|
||||
: [
|
||||
@@ -514,25 +495,17 @@ export function AssistantEditor({
|
||||
num_chunks: numChunks,
|
||||
};
|
||||
|
||||
let promptResponse;
|
||||
let personaResponse;
|
||||
if (isUpdate) {
|
||||
[promptResponse, personaResponse] = await updatePersona({
|
||||
id: existingPersona.id,
|
||||
existingPromptId: existingPrompt?.id,
|
||||
...submissionData,
|
||||
});
|
||||
personaResponse = await updatePersona(
|
||||
existingPersona.id,
|
||||
submissionData
|
||||
);
|
||||
} else {
|
||||
[promptResponse, personaResponse] = await createPersona({
|
||||
...submissionData,
|
||||
is_default_persona: admin!,
|
||||
});
|
||||
personaResponse = await createPersona(submissionData);
|
||||
}
|
||||
|
||||
let error = null;
|
||||
if (!promptResponse.ok) {
|
||||
error = await promptResponse.text();
|
||||
}
|
||||
|
||||
if (!personaResponse) {
|
||||
error = "Failed to create Assistant - no response received";
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { FullLLMProvider } from "../configuration/llm/interfaces";
|
||||
import { Persona, StarterMessage } from "./interfaces";
|
||||
|
||||
interface PersonaCreationRequest {
|
||||
interface PersonaUpsertRequest {
|
||||
name: string;
|
||||
description: string;
|
||||
system_prompt: string;
|
||||
@@ -10,6 +10,36 @@ interface PersonaCreationRequest {
|
||||
num_chunks: number | null;
|
||||
include_citations: boolean;
|
||||
is_public: boolean;
|
||||
recency_bias: string;
|
||||
prompt_ids: number[];
|
||||
llm_filter_extraction: boolean;
|
||||
llm_relevance_filter: boolean | null;
|
||||
llm_model_provider_override: string | null;
|
||||
llm_model_version_override: string | null;
|
||||
starter_messages: StarterMessage[] | null;
|
||||
users?: string[];
|
||||
groups: number[];
|
||||
tool_ids: number[];
|
||||
icon_color: string | null;
|
||||
icon_shape: number | null;
|
||||
remove_image?: boolean;
|
||||
uploaded_image_id: string | null;
|
||||
search_start_date: Date | null;
|
||||
is_default_persona: boolean;
|
||||
display_priority: number | null;
|
||||
label_ids: number[] | null;
|
||||
}
|
||||
|
||||
export interface PersonaUpsertParameters {
|
||||
name: string;
|
||||
description: string;
|
||||
system_prompt: string;
|
||||
existing_prompt_id: number | null;
|
||||
task_prompt: string;
|
||||
document_set_ids: number[];
|
||||
num_chunks: number | null;
|
||||
include_citations: boolean;
|
||||
is_public: boolean;
|
||||
llm_relevance_filter: boolean | null;
|
||||
llm_model_provider_override: string | null;
|
||||
llm_model_version_override: string | null;
|
||||
@@ -20,94 +50,10 @@ interface PersonaCreationRequest {
|
||||
icon_color: string | null;
|
||||
icon_shape: number | null;
|
||||
remove_image?: boolean;
|
||||
uploaded_image: File | null;
|
||||
search_start_date: Date | null;
|
||||
uploaded_image: File | null;
|
||||
is_default_persona: boolean;
|
||||
label_ids?: number[];
|
||||
}
|
||||
|
||||
interface PersonaUpdateRequest {
|
||||
id: number;
|
||||
existingPromptId: number | undefined;
|
||||
name: string;
|
||||
description: string;
|
||||
system_prompt: string;
|
||||
task_prompt: string;
|
||||
document_set_ids: number[];
|
||||
num_chunks: number | null;
|
||||
include_citations: boolean;
|
||||
is_public: boolean;
|
||||
llm_relevance_filter: boolean | null;
|
||||
llm_model_provider_override: string | null;
|
||||
llm_model_version_override: string | null;
|
||||
starter_messages: StarterMessage[] | null;
|
||||
users?: string[];
|
||||
groups: number[];
|
||||
tool_ids: number[];
|
||||
icon_color: string | null;
|
||||
icon_shape: number | null;
|
||||
remove_image: boolean;
|
||||
uploaded_image: File | null;
|
||||
search_start_date: Date | null;
|
||||
label_ids?: number[];
|
||||
}
|
||||
|
||||
function promptNameFromPersonaName(personaName: string) {
|
||||
return `default-prompt__${personaName}`;
|
||||
}
|
||||
|
||||
function createPrompt({
|
||||
personaName,
|
||||
systemPrompt,
|
||||
taskPrompt,
|
||||
includeCitations,
|
||||
}: {
|
||||
personaName: string;
|
||||
systemPrompt: string;
|
||||
taskPrompt: string;
|
||||
includeCitations: boolean;
|
||||
}) {
|
||||
return fetch("/api/prompt", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: promptNameFromPersonaName(personaName),
|
||||
description: `Default prompt for persona ${personaName}`,
|
||||
system_prompt: systemPrompt,
|
||||
task_prompt: taskPrompt,
|
||||
include_citations: includeCitations,
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
function updatePrompt({
|
||||
promptId,
|
||||
personaName,
|
||||
systemPrompt,
|
||||
taskPrompt,
|
||||
includeCitations,
|
||||
}: {
|
||||
promptId: number;
|
||||
personaName: string;
|
||||
systemPrompt: string;
|
||||
taskPrompt: string;
|
||||
includeCitations: boolean;
|
||||
}) {
|
||||
return fetch(`/api/prompt/${promptId}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
name: promptNameFromPersonaName(personaName),
|
||||
description: `Default prompt for persona ${personaName}`,
|
||||
system_prompt: systemPrompt,
|
||||
task_prompt: taskPrompt,
|
||||
include_citations: includeCitations,
|
||||
}),
|
||||
});
|
||||
label_ids: number[] | null;
|
||||
}
|
||||
|
||||
export const createPersonaLabel = (name: string) => {
|
||||
@@ -144,56 +90,57 @@ export const updatePersonaLabel = (
|
||||
});
|
||||
};
|
||||
|
||||
function buildPersonaAPIBody(
|
||||
creationRequest: PersonaCreationRequest | PersonaUpdateRequest,
|
||||
promptId: number,
|
||||
function buildPersonaUpsertRequest(
|
||||
creationRequest: PersonaUpsertParameters,
|
||||
uploaded_image_id: string | null
|
||||
) {
|
||||
): PersonaUpsertRequest {
|
||||
const {
|
||||
name,
|
||||
description,
|
||||
system_prompt,
|
||||
task_prompt,
|
||||
document_set_ids,
|
||||
num_chunks,
|
||||
llm_relevance_filter,
|
||||
include_citations,
|
||||
is_public,
|
||||
groups,
|
||||
existing_prompt_id,
|
||||
users,
|
||||
tool_ids,
|
||||
icon_color,
|
||||
icon_shape,
|
||||
remove_image,
|
||||
search_start_date,
|
||||
label_ids,
|
||||
} = creationRequest;
|
||||
|
||||
const is_default_persona =
|
||||
"is_default_persona" in creationRequest
|
||||
? creationRequest.is_default_persona
|
||||
: false;
|
||||
|
||||
return {
|
||||
name,
|
||||
description,
|
||||
num_chunks,
|
||||
llm_relevance_filter,
|
||||
llm_filter_extraction: false,
|
||||
is_public,
|
||||
recency_bias: "base_decay",
|
||||
prompt_ids: [promptId],
|
||||
system_prompt,
|
||||
task_prompt,
|
||||
document_set_ids,
|
||||
llm_model_provider_override: creationRequest.llm_model_provider_override,
|
||||
llm_model_version_override: creationRequest.llm_model_version_override,
|
||||
starter_messages: creationRequest.starter_messages,
|
||||
users,
|
||||
num_chunks,
|
||||
include_citations,
|
||||
is_public,
|
||||
uploaded_image_id,
|
||||
groups,
|
||||
users,
|
||||
tool_ids,
|
||||
icon_color,
|
||||
icon_shape,
|
||||
uploaded_image_id,
|
||||
remove_image,
|
||||
search_start_date,
|
||||
is_default_persona,
|
||||
label_ids,
|
||||
is_default_persona: creationRequest.is_default_persona ?? false,
|
||||
recency_bias: "base_decay",
|
||||
prompt_ids: existing_prompt_id ? [existing_prompt_id] : [],
|
||||
llm_filter_extraction: false,
|
||||
llm_relevance_filter: creationRequest.llm_relevance_filter ?? null,
|
||||
llm_model_provider_override:
|
||||
creationRequest.llm_model_provider_override ?? null,
|
||||
llm_model_version_override:
|
||||
creationRequest.llm_model_version_override ?? null,
|
||||
starter_messages: creationRequest.starter_messages ?? null,
|
||||
display_priority: null,
|
||||
label_ids: creationRequest.label_ids ?? null,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -215,92 +162,52 @@ export async function uploadFile(file: File): Promise<string | null> {
|
||||
}
|
||||
|
||||
export async function createPersona(
|
||||
personaCreationRequest: PersonaCreationRequest
|
||||
): Promise<[Response, Response | null]> {
|
||||
// first create prompt
|
||||
const createPromptResponse = await createPrompt({
|
||||
personaName: personaCreationRequest.name,
|
||||
systemPrompt: personaCreationRequest.system_prompt,
|
||||
taskPrompt: personaCreationRequest.task_prompt,
|
||||
includeCitations: personaCreationRequest.include_citations,
|
||||
});
|
||||
const promptId = createPromptResponse.ok
|
||||
? (await createPromptResponse.json()).id
|
||||
: null;
|
||||
|
||||
personaUpsertParams: PersonaUpsertParameters
|
||||
): Promise<Response | null> {
|
||||
let fileId = null;
|
||||
if (personaCreationRequest.uploaded_image) {
|
||||
fileId = await uploadFile(personaCreationRequest.uploaded_image);
|
||||
if (personaUpsertParams.uploaded_image) {
|
||||
fileId = await uploadFile(personaUpsertParams.uploaded_image);
|
||||
if (!fileId) {
|
||||
return [createPromptResponse, null];
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
const createPersonaResponse =
|
||||
promptId !== null
|
||||
? await fetch("/api/persona", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(
|
||||
buildPersonaAPIBody(personaCreationRequest, promptId, fileId)
|
||||
),
|
||||
})
|
||||
: null;
|
||||
const createPersonaResponse = await fetch("/api/persona", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(
|
||||
buildPersonaUpsertRequest(personaUpsertParams, fileId)
|
||||
),
|
||||
});
|
||||
|
||||
return [createPromptResponse, createPersonaResponse];
|
||||
return createPersonaResponse;
|
||||
}
|
||||
|
||||
export async function updatePersona(
|
||||
personaUpdateRequest: PersonaUpdateRequest
|
||||
): Promise<[Response, Response | null]> {
|
||||
const { id, existingPromptId } = personaUpdateRequest;
|
||||
|
||||
let promptResponse;
|
||||
let promptId: number | null = null;
|
||||
if (existingPromptId !== undefined) {
|
||||
promptResponse = await updatePrompt({
|
||||
promptId: existingPromptId,
|
||||
personaName: personaUpdateRequest.name,
|
||||
systemPrompt: personaUpdateRequest.system_prompt,
|
||||
taskPrompt: personaUpdateRequest.task_prompt,
|
||||
includeCitations: personaUpdateRequest.include_citations,
|
||||
});
|
||||
promptId = existingPromptId;
|
||||
} else {
|
||||
promptResponse = await createPrompt({
|
||||
personaName: personaUpdateRequest.name,
|
||||
systemPrompt: personaUpdateRequest.system_prompt,
|
||||
taskPrompt: personaUpdateRequest.task_prompt,
|
||||
includeCitations: personaUpdateRequest.include_citations,
|
||||
});
|
||||
promptId = promptResponse.ok
|
||||
? ((await promptResponse.json()).id as number)
|
||||
: null;
|
||||
}
|
||||
id: number,
|
||||
personaUpsertParams: PersonaUpsertParameters
|
||||
): Promise<Response | null> {
|
||||
let fileId = null;
|
||||
if (personaUpdateRequest.uploaded_image) {
|
||||
fileId = await uploadFile(personaUpdateRequest.uploaded_image);
|
||||
if (personaUpsertParams.uploaded_image) {
|
||||
fileId = await uploadFile(personaUpsertParams.uploaded_image);
|
||||
if (!fileId) {
|
||||
return [promptResponse, null];
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
const updatePersonaResponse =
|
||||
promptResponse.ok && promptId !== null
|
||||
? await fetch(`/api/persona/${id}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(
|
||||
buildPersonaAPIBody(personaUpdateRequest, promptId, fileId)
|
||||
),
|
||||
})
|
||||
: null;
|
||||
const updatePersonaResponse = await fetch(`/api/persona/${id}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(
|
||||
buildPersonaUpsertRequest(personaUpsertParams, fileId)
|
||||
),
|
||||
});
|
||||
|
||||
return [promptResponse, updatePersonaResponse];
|
||||
return updatePersonaResponse;
|
||||
}
|
||||
|
||||
export function deletePersona(personaId: number) {
|
||||
@@ -309,25 +216,6 @@ export function deletePersona(personaId: number) {
|
||||
});
|
||||
}
|
||||
|
||||
export function buildFinalPrompt(
|
||||
systemPrompt: string,
|
||||
taskPrompt: string,
|
||||
retrievalDisabled: boolean
|
||||
) {
|
||||
const queryString = Object.entries({
|
||||
system_prompt: systemPrompt,
|
||||
task_prompt: taskPrompt,
|
||||
retrieval_disabled: retrievalDisabled,
|
||||
})
|
||||
.map(
|
||||
([key, value]) =>
|
||||
`${encodeURIComponent(key)}=${encodeURIComponent(value)}`
|
||||
)
|
||||
.join("&");
|
||||
|
||||
return fetch(`/api/persona/utils/prompt-explorer?${queryString}`);
|
||||
}
|
||||
|
||||
function smallerNumberFirstComparator(a: number, b: number) {
|
||||
return a > b ? 1 : -1;
|
||||
}
|
||||
|
||||
@@ -64,7 +64,13 @@ export const SlackTokensForm = ({
|
||||
router.push(`/admin/bots/${encodeURIComponent(botId)}`);
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg = responseJson.detail || responseJson.message;
|
||||
let errorMsg = responseJson.detail || responseJson.message;
|
||||
|
||||
if (errorMsg.includes("Invalid bot token:")) {
|
||||
errorMsg = "Slack Bot Token is invalid";
|
||||
} else if (errorMsg.includes("Invalid app token:")) {
|
||||
errorMsg = "Slack App Token is invalid";
|
||||
}
|
||||
setPopup({
|
||||
message: isUpdate
|
||||
? `Error updating Slack Bot - ${errorMsg}`
|
||||
|
||||
@@ -1438,10 +1438,10 @@ export function ChatPage({
|
||||
}
|
||||
}
|
||||
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
parentMessage =
|
||||
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
|
||||
const updateFn = (messages: Message[]) => {
|
||||
const replacementsMap = regenerationRequest
|
||||
|
||||
@@ -9,8 +9,6 @@ import {
|
||||
} from "react-icons/fi";
|
||||
import { FeedbackType } from "../types";
|
||||
import React, {
|
||||
memo,
|
||||
ReactNode,
|
||||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
@@ -21,7 +19,10 @@ import React, {
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import { OnyxDocument, FilteredOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { SearchSummary } from "./SearchSummary";
|
||||
|
||||
import {
|
||||
markdownToHtml,
|
||||
getMarkdownForSelection,
|
||||
} from "@/app/chat/message/codeUtils";
|
||||
import { SkippedSearch } from "./SkippedSearch";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import { CopyButton } from "@/components/CopyButton";
|
||||
@@ -37,12 +38,10 @@ import { DocumentPreview } from "../files/documents/DocumentPreview";
|
||||
import { InMessageImage } from "../files/images/InMessageImage";
|
||||
import { CodeBlock } from "./CodeBlock";
|
||||
import rehypePrism from "rehype-prism-plus";
|
||||
|
||||
import "prismjs/themes/prism-tomorrow.css";
|
||||
import "./custom-code-styles.css";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
|
||||
import { LikeFeedback, DislikeFeedback } from "@/components/icons/icons";
|
||||
import {
|
||||
CustomTooltip,
|
||||
@@ -68,7 +67,6 @@ import CsvContent from "../../../components/tools/CSVContent";
|
||||
import SourceCard, {
|
||||
SeeMoreBlock,
|
||||
} from "@/components/chat_search/sources/SourceCard";
|
||||
|
||||
import remarkMath from "remark-math";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
@@ -373,15 +371,28 @@ export const AIMessage = ({
|
||||
);
|
||||
|
||||
const renderedMarkdown = useMemo(() => {
|
||||
if (typeof finalContent !== "string") {
|
||||
return finalContent;
|
||||
}
|
||||
|
||||
// Create a hidden div with the HTML content for copying
|
||||
const htmlContent = markdownToHtml(finalContent);
|
||||
|
||||
return (
|
||||
<ReactMarkdown
|
||||
className="prose max-w-full text-base"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
>
|
||||
{finalContent as string}
|
||||
</ReactMarkdown>
|
||||
<>
|
||||
<div
|
||||
style={{ position: "absolute", left: "-9999px" }}
|
||||
dangerouslySetInnerHTML={{ __html: htmlContent }}
|
||||
/>
|
||||
<ReactMarkdown
|
||||
className="prose max-w-full text-base"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
>
|
||||
{finalContent}
|
||||
</ReactMarkdown>
|
||||
</>
|
||||
);
|
||||
}, [finalContent, markdownComponents]);
|
||||
|
||||
@@ -513,7 +524,68 @@ export const AIMessage = ({
|
||||
|
||||
{typeof content === "string" ? (
|
||||
<div className="overflow-x-visible max-w-content-max">
|
||||
{renderedMarkdown}
|
||||
<div
|
||||
contentEditable="true"
|
||||
suppressContentEditableWarning
|
||||
className="focus:outline-none cursor-text select-text"
|
||||
style={{
|
||||
MozUserModify: "read-only",
|
||||
WebkitUserModify: "read-only",
|
||||
}}
|
||||
onCopy={(e) => {
|
||||
e.preventDefault();
|
||||
const selection = window.getSelection();
|
||||
const selectedPlainText =
|
||||
selection?.toString() || "";
|
||||
if (!selectedPlainText) {
|
||||
// If no text is selected, copy the full content
|
||||
const contentStr =
|
||||
typeof content === "string"
|
||||
? content
|
||||
: (
|
||||
content as JSX.Element
|
||||
).props?.children?.toString() || "";
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[
|
||||
typeof content === "string"
|
||||
? markdownToHtml(content)
|
||||
: contentStr,
|
||||
],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob([contentStr], {
|
||||
type: "text/plain",
|
||||
}),
|
||||
});
|
||||
navigator.clipboard.write([clipboardItem]);
|
||||
return;
|
||||
}
|
||||
|
||||
const contentStr =
|
||||
typeof content === "string"
|
||||
? content
|
||||
: (
|
||||
content as JSX.Element
|
||||
).props?.children?.toString() || "";
|
||||
const markdownText = getMarkdownForSelection(
|
||||
contentStr,
|
||||
selectedPlainText
|
||||
);
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[markdownToHtml(markdownText)],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob([selectedPlainText], {
|
||||
type: "text/plain",
|
||||
}),
|
||||
});
|
||||
navigator.clipboard.write([clipboardItem]);
|
||||
}}
|
||||
>
|
||||
{renderedMarkdown}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
content
|
||||
@@ -559,7 +631,16 @@ export const AIMessage = ({
|
||||
)}
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton content={content.toString()} />
|
||||
<CopyButton
|
||||
content={
|
||||
typeof content === "string"
|
||||
? {
|
||||
html: markdownToHtml(content),
|
||||
plainText: content,
|
||||
}
|
||||
: content.toString()
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
<CustomTooltip showTick line content="Good response">
|
||||
<HoverableIcon
|
||||
@@ -644,7 +725,16 @@ export const AIMessage = ({
|
||||
)}
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton content={content.toString()} />
|
||||
<CopyButton
|
||||
content={
|
||||
typeof content === "string"
|
||||
? {
|
||||
html: markdownToHtml(content),
|
||||
plainText: content,
|
||||
}
|
||||
: content.toString()
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
|
||||
<CustomTooltip showTick line content="Good response">
|
||||
|
||||
176
web/src/app/chat/message/__tests__/codeUtils.test.ts
Normal file
176
web/src/app/chat/message/__tests__/codeUtils.test.ts
Normal file
@@ -0,0 +1,176 @@
|
||||
import { markdownToHtml, parseMarkdownToSegments } from "../codeUtils";
|
||||
|
||||
describe("markdownToHtml", () => {
|
||||
test("converts bold text with asterisks and underscores", () => {
|
||||
expect(markdownToHtml("This is **bold** text")).toBe(
|
||||
"<p>This is <strong>bold</strong> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is __bold__ text")).toBe(
|
||||
"<p>This is <strong>bold</strong> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("converts italic text with asterisks and underscores", () => {
|
||||
expect(markdownToHtml("This is *italic* text")).toBe(
|
||||
"<p>This is <em>italic</em> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _italic_ text")).toBe(
|
||||
"<p>This is <em>italic</em> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles mixed bold and italic", () => {
|
||||
expect(markdownToHtml("This is **bold** and *italic* text")).toBe(
|
||||
"<p>This is <strong>bold</strong> and <em>italic</em> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is __bold__ and _italic_ text")).toBe(
|
||||
"<p>This is <strong>bold</strong> and <em>italic</em> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles text with spaces and special characters", () => {
|
||||
expect(markdownToHtml("This is *as delicious and* tasty")).toBe(
|
||||
"<p>This is <em>as delicious and</em> tasty</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _as delicious and_ tasty")).toBe(
|
||||
"<p>This is <em>as delicious and</em> tasty</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles multi-paragraph text with italics", () => {
|
||||
const input =
|
||||
"Sure! Here is a sentence with one italicized word:\n\nThe cake was _delicious_ and everyone enjoyed it.";
|
||||
expect(markdownToHtml(input)).toBe(
|
||||
"<p>Sure! Here is a sentence with one italicized word:</p>\n<p>The cake was <em>delicious</em> and everyone enjoyed it.</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles malformed markdown without crashing", () => {
|
||||
expect(markdownToHtml("This is *malformed markdown")).toBe(
|
||||
"<p>This is *malformed markdown</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _also malformed")).toBe(
|
||||
"<p>This is _also malformed</p>"
|
||||
);
|
||||
expect(markdownToHtml("This has **unclosed bold")).toBe(
|
||||
"<p>This has **unclosed bold</p>"
|
||||
);
|
||||
expect(markdownToHtml("This has __unclosed bold")).toBe(
|
||||
"<p>This has __unclosed bold</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles empty or null input", () => {
|
||||
expect(markdownToHtml("")).toBe("");
|
||||
expect(markdownToHtml(" ")).toBe("");
|
||||
expect(markdownToHtml("\n")).toBe("");
|
||||
});
|
||||
|
||||
test("handles extremely long input without crashing", () => {
|
||||
const longText = "This is *italic* ".repeat(1000);
|
||||
expect(() => markdownToHtml(longText)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("parseMarkdownToSegments", () => {
|
||||
test("parses italic text with asterisks", () => {
|
||||
const segments = parseMarkdownToSegments("This is *italic* text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "italic", text: "italic", raw: "*italic*", length: 6 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses italic text with underscores", () => {
|
||||
const segments = parseMarkdownToSegments("This is _italic_ text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "italic", text: "italic", raw: "_italic_", length: 6 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses bold text with asterisks", () => {
|
||||
const segments = parseMarkdownToSegments("This is **bold** text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "bold", text: "bold", raw: "**bold**", length: 4 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses bold text with underscores", () => {
|
||||
const segments = parseMarkdownToSegments("This is __bold__ text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "bold", text: "bold", raw: "__bold__", length: 4 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses text with spaces and special characters in italics", () => {
|
||||
const segments = parseMarkdownToSegments(
|
||||
"The cake was _delicious_ and everyone enjoyed it."
|
||||
);
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "The cake was ", raw: "The cake was ", length: 13 },
|
||||
{ type: "italic", text: "delicious", raw: "_delicious_", length: 9 },
|
||||
{
|
||||
type: "text",
|
||||
text: " and everyone enjoyed it.",
|
||||
raw: " and everyone enjoyed it.",
|
||||
length: 25,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses multi-paragraph text with italics", () => {
|
||||
const segments = parseMarkdownToSegments(
|
||||
"Sure! Here is a sentence with one italicized word:\n\nThe cake was _delicious_ and everyone enjoyed it."
|
||||
);
|
||||
expect(segments).toEqual([
|
||||
{
|
||||
type: "text",
|
||||
text: "Sure! Here is a sentence with one italicized word:\n\nThe cake was ",
|
||||
raw: "Sure! Here is a sentence with one italicized word:\n\nThe cake was ",
|
||||
length: 65,
|
||||
},
|
||||
{ type: "italic", text: "delicious", raw: "_delicious_", length: 9 },
|
||||
{
|
||||
type: "text",
|
||||
text: " and everyone enjoyed it.",
|
||||
raw: " and everyone enjoyed it.",
|
||||
length: 25,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test("handles malformed markdown without crashing", () => {
|
||||
expect(() => parseMarkdownToSegments("This is *malformed")).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This is _also malformed")
|
||||
).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This has **unclosed bold")
|
||||
).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This has __unclosed bold")
|
||||
).not.toThrow();
|
||||
});
|
||||
|
||||
test("handles empty or null input", () => {
|
||||
expect(parseMarkdownToSegments("")).toEqual([]);
|
||||
expect(parseMarkdownToSegments(" ")).toEqual([
|
||||
{ type: "text", text: " ", raw: " ", length: 1 },
|
||||
]);
|
||||
expect(parseMarkdownToSegments("\n")).toEqual([
|
||||
{ type: "text", text: "\n", raw: "\n", length: 1 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("handles extremely long input without crashing", () => {
|
||||
const longText = "This is *italic* ".repeat(1000);
|
||||
expect(() => parseMarkdownToSegments(longText)).not.toThrow();
|
||||
});
|
||||
});
|
||||
@@ -82,3 +82,252 @@ export const preprocessLaTeX = (content: string) => {
|
||||
|
||||
return inlineProcessedContent;
|
||||
};
|
||||
|
||||
export const markdownToHtml = (content: string): string => {
|
||||
if (!content || !content.trim()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Basic markdown to HTML conversion for common patterns
|
||||
const processedContent = content
|
||||
.replace(/(\*\*|__)((?:(?!\1).)*?)\1/g, "<strong>$2</strong>") // Bold with ** or __, non-greedy and no nesting
|
||||
.replace(/(\*|_)([^*_\n]+?)\1(?!\*|_)/g, "<em>$2</em>"); // Italic with * or _
|
||||
|
||||
// Handle code blocks and links
|
||||
const withCodeAndLinks = processedContent
|
||||
.replace(/`([^`]+)`/g, "<code>$1</code>") // Inline code
|
||||
.replace(
|
||||
/```(\w*)\n([\s\S]*?)```/g,
|
||||
(_, lang, code) =>
|
||||
`<pre><code class="language-${lang}">${code.trim()}</code></pre>`
|
||||
) // Code blocks
|
||||
.replace(/\[([^\]]+)\]\(([^)]+)\)/g, '<a href="$2">$1</a>'); // Links
|
||||
|
||||
// Handle paragraphs
|
||||
return withCodeAndLinks
|
||||
.split(/\n\n+/)
|
||||
.map((para) => para.trim())
|
||||
.filter((para) => para.length > 0)
|
||||
.map((para) => `<p>${para}</p>`)
|
||||
.join("\n");
|
||||
};
|
||||
|
||||
interface MarkdownSegment {
|
||||
type: "text" | "link" | "code" | "bold" | "italic" | "codeblock";
|
||||
text: string; // The visible/plain text
|
||||
raw: string; // The raw markdown including syntax
|
||||
length: number; // Length of the visible text
|
||||
}
|
||||
|
||||
export function parseMarkdownToSegments(markdown: string): MarkdownSegment[] {
|
||||
if (!markdown) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const segments: MarkdownSegment[] = [];
|
||||
let currentIndex = 0;
|
||||
const maxIterations = markdown.length * 2; // Prevent infinite loops
|
||||
let iterations = 0;
|
||||
|
||||
while (currentIndex < markdown.length && iterations < maxIterations) {
|
||||
iterations++;
|
||||
let matched = false;
|
||||
|
||||
// Check for code blocks first (they take precedence)
|
||||
const codeBlockMatch = markdown
|
||||
.slice(currentIndex)
|
||||
.match(/^```(\w*)\n([\s\S]*?)```/);
|
||||
if (codeBlockMatch && codeBlockMatch[0]) {
|
||||
const [fullMatch, , code] = codeBlockMatch;
|
||||
segments.push({
|
||||
type: "codeblock",
|
||||
text: code || "",
|
||||
raw: fullMatch,
|
||||
length: (code || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for inline code
|
||||
const inlineCodeMatch = markdown.slice(currentIndex).match(/^`([^`]+)`/);
|
||||
if (inlineCodeMatch && inlineCodeMatch[0]) {
|
||||
const [fullMatch, code] = inlineCodeMatch;
|
||||
segments.push({
|
||||
type: "code",
|
||||
text: code || "",
|
||||
raw: fullMatch,
|
||||
length: (code || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for links
|
||||
const linkMatch = markdown
|
||||
.slice(currentIndex)
|
||||
.match(/^\[([^\]]+)\]\(([^)]+)\)/);
|
||||
if (linkMatch && linkMatch[0]) {
|
||||
const [fullMatch, text] = linkMatch;
|
||||
segments.push({
|
||||
type: "link",
|
||||
text: text || "",
|
||||
raw: fullMatch,
|
||||
length: (text || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for bold
|
||||
const boldMatch = markdown
|
||||
.slice(currentIndex)
|
||||
.match(/^(\*\*|__)([^*_\n]*?)\1/);
|
||||
if (boldMatch && boldMatch[0]) {
|
||||
const [fullMatch, , text] = boldMatch;
|
||||
segments.push({
|
||||
type: "bold",
|
||||
text: text || "",
|
||||
raw: fullMatch,
|
||||
length: (text || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for italic
|
||||
const italicMatch = markdown
|
||||
.slice(currentIndex)
|
||||
.match(/^(\*|_)([^*_\n]+?)\1(?!\*|_)/);
|
||||
if (italicMatch && italicMatch[0]) {
|
||||
const [fullMatch, , text] = italicMatch;
|
||||
segments.push({
|
||||
type: "italic",
|
||||
text: text || "",
|
||||
raw: fullMatch,
|
||||
length: (text || "").length,
|
||||
});
|
||||
currentIndex += fullMatch.length;
|
||||
matched = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// If no matches were found, handle regular text
|
||||
if (!matched) {
|
||||
let nextSpecialChar = markdown.slice(currentIndex).search(/[`\[*_]/);
|
||||
if (nextSpecialChar === -1) {
|
||||
// No more special characters, add the rest as text
|
||||
const text = markdown.slice(currentIndex);
|
||||
if (text) {
|
||||
segments.push({
|
||||
type: "text",
|
||||
text: text,
|
||||
raw: text,
|
||||
length: text.length,
|
||||
});
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
// Add the text up to the next special character
|
||||
const text = markdown.slice(
|
||||
currentIndex,
|
||||
currentIndex + nextSpecialChar
|
||||
);
|
||||
if (text) {
|
||||
segments.push({
|
||||
type: "text",
|
||||
text: text,
|
||||
raw: text,
|
||||
length: text.length,
|
||||
});
|
||||
}
|
||||
currentIndex += nextSpecialChar;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return segments;
|
||||
}
|
||||
|
||||
export function getMarkdownForSelection(
|
||||
content: string,
|
||||
selectedText: string
|
||||
): string {
|
||||
const segments = parseMarkdownToSegments(content);
|
||||
|
||||
// Build plain text and create mapping to markdown segments
|
||||
let plainText = "";
|
||||
const markdownPieces: string[] = [];
|
||||
let currentPlainIndex = 0;
|
||||
|
||||
segments.forEach((segment) => {
|
||||
plainText += segment.text;
|
||||
markdownPieces.push(segment.raw);
|
||||
currentPlainIndex += segment.length;
|
||||
});
|
||||
|
||||
// Find the selection in the plain text
|
||||
const startIndex = plainText.indexOf(selectedText);
|
||||
if (startIndex === -1) {
|
||||
return selectedText;
|
||||
}
|
||||
|
||||
const endIndex = startIndex + selectedText.length;
|
||||
|
||||
// Find which segments the selection spans
|
||||
let currentIndex = 0;
|
||||
let result = "";
|
||||
let selectionStart = startIndex;
|
||||
let selectionEnd = endIndex;
|
||||
|
||||
segments.forEach((segment) => {
|
||||
const segmentStart = currentIndex;
|
||||
const segmentEnd = segmentStart + segment.length;
|
||||
|
||||
// Check if this segment overlaps with the selection
|
||||
if (segmentEnd > selectionStart && segmentStart < selectionEnd) {
|
||||
// Calculate how much of this segment to include
|
||||
const overlapStart = Math.max(0, selectionStart - segmentStart);
|
||||
const overlapEnd = Math.min(segment.length, selectionEnd - segmentStart);
|
||||
|
||||
if (segment.type === "text") {
|
||||
const textPortion = segment.text.slice(overlapStart, overlapEnd);
|
||||
result += textPortion;
|
||||
} else {
|
||||
// For markdown elements, wrap just the selected portion with the appropriate markdown
|
||||
const selectedPortion = segment.text.slice(overlapStart, overlapEnd);
|
||||
|
||||
switch (segment.type) {
|
||||
case "bold":
|
||||
result += `**${selectedPortion}**`;
|
||||
break;
|
||||
case "italic":
|
||||
result += `*${selectedPortion}*`;
|
||||
break;
|
||||
case "code":
|
||||
result += `\`${selectedPortion}\``;
|
||||
break;
|
||||
case "link":
|
||||
// For links, we need to preserve the URL if it exists in the raw markdown
|
||||
const urlMatch = segment.raw.match(/\]\((.*?)\)/);
|
||||
const url = urlMatch ? urlMatch[1] : "";
|
||||
result += `[${selectedPortion}](${url})`;
|
||||
break;
|
||||
case "codeblock":
|
||||
result += `\`\`\`\n${selectedPortion}\n\`\`\``;
|
||||
break;
|
||||
default:
|
||||
result += selectedPortion;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
currentIndex += segment.length;
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import { cn } from "@/lib/utils";
|
||||
import { CalendarIcon } from "lucide-react";
|
||||
import { format } from "date-fns";
|
||||
import { getXDaysAgo } from "./dateUtils";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
|
||||
export const THIRTY_DAYS = "30d";
|
||||
|
||||
@@ -84,8 +83,16 @@ export const DateRangeSelector = memo(function DateRangeSelector({
|
||||
defaultMonth={value?.from}
|
||||
selected={value}
|
||||
onSelect={(range) => {
|
||||
if (range?.from && range?.to) {
|
||||
onValueChange({ from: range.from, to: range.to });
|
||||
if (range?.from) {
|
||||
if (range.to) {
|
||||
// Normal range selection when initialized with a range
|
||||
onValueChange({ from: range.from, to: range.to });
|
||||
} else {
|
||||
// Single date selection when initilized without a range
|
||||
const to = new Date(range.from);
|
||||
const from = new Date(to.setDate(to.getDate() - 1));
|
||||
onValueChange({ from, to });
|
||||
}
|
||||
}
|
||||
}}
|
||||
numberOfMonths={2}
|
||||
|
||||
@@ -65,27 +65,6 @@ export const useOnyxBotAnalytics = (timeRange: DateRangePickerValue) => {
|
||||
};
|
||||
};
|
||||
|
||||
export const useQueryHistory = ({
|
||||
selectedFeedbackType,
|
||||
timeRange,
|
||||
}: {
|
||||
selectedFeedbackType: Feedback | null;
|
||||
timeRange: DateRange;
|
||||
}) => {
|
||||
const url = buildApiPath("/api/admin/chat-session-history", {
|
||||
feedback_type: selectedFeedbackType,
|
||||
start: convertDateToStartOfDay(timeRange?.from)?.toISOString(),
|
||||
end: convertDateToEndOfDay(timeRange?.to)?.toISOString(),
|
||||
});
|
||||
|
||||
const swrResponse = useSWR<ChatSessionMinimal[]>(url, errorHandlingFetcher);
|
||||
|
||||
return {
|
||||
...swrResponse,
|
||||
refreshQueryHistory: () => mutate(url),
|
||||
};
|
||||
};
|
||||
|
||||
export function getDatesList(startDate: Date): string[] {
|
||||
const datesList: string[] = [];
|
||||
const endDate = new Date(); // current date
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { useQueryHistory, useTimeRange } from "../lib";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import {
|
||||
Table,
|
||||
@@ -20,8 +19,8 @@ import {
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { ChatSessionMinimal } from "../usage/types";
|
||||
import { timestampToReadableDate } from "@/lib/dateUtils";
|
||||
import { FiFrown, FiMinus, FiSmile } from "react-icons/fi";
|
||||
import { useCallback, useState } from "react";
|
||||
import { FiFrown, FiMinus, FiSmile, FiMeh } from "react-icons/fi";
|
||||
import { useCallback, useState, useMemo } from "react";
|
||||
import { Feedback } from "@/lib/types";
|
||||
import { DateRange, DateRangeSelector } from "../DateRangeSelector";
|
||||
import { PageSelector } from "@/components/PageSelector";
|
||||
@@ -29,8 +28,11 @@ import Link from "next/link";
|
||||
import { FeedbackBadge } from "./FeedbackBadge";
|
||||
import { DownloadAsCSV } from "./DownloadAsCSV";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import usePaginatedFetch from "@/hooks/usePaginatedFetch";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
|
||||
const NUM_IN_PAGE = 20;
|
||||
const ITEMS_PER_PAGE = 20;
|
||||
const PAGES_PER_BATCH = 2;
|
||||
|
||||
function QueryHistoryTableRow({
|
||||
chatSessionMinimal,
|
||||
@@ -108,6 +110,12 @@ function SelectFeedbackType({
|
||||
<span>Dislike</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
<SelectItem value="mixed">
|
||||
<div className="flex items-center gap-2">
|
||||
<FiMeh className="h-4 w-4" />
|
||||
<span>Mixed</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
@@ -116,31 +124,55 @@ function SelectFeedbackType({
|
||||
}
|
||||
|
||||
export function QueryHistoryTable() {
|
||||
const [selectedFeedbackType, setSelectedFeedbackType] = useState<
|
||||
Feedback | "all"
|
||||
>("all");
|
||||
const [timeRange, setTimeRange] = useTimeRange();
|
||||
const [dateRange, setDateRange] = useState<DateRange>(undefined);
|
||||
const [filters, setFilters] = useState<{
|
||||
feedback_type?: Feedback | "all";
|
||||
start_time?: string;
|
||||
end_time?: string;
|
||||
}>({});
|
||||
|
||||
const { data: chatSessionData } = useQueryHistory({
|
||||
selectedFeedbackType:
|
||||
selectedFeedbackType === "all" ? null : selectedFeedbackType,
|
||||
timeRange,
|
||||
const {
|
||||
currentPageData: chatSessionData,
|
||||
isLoading,
|
||||
error,
|
||||
currentPage,
|
||||
totalPages,
|
||||
goToPage,
|
||||
refresh,
|
||||
} = usePaginatedFetch<ChatSessionMinimal>({
|
||||
itemsPerPage: ITEMS_PER_PAGE,
|
||||
pagesPerBatch: PAGES_PER_BATCH,
|
||||
endpoint: "/api/admin/chat-session-history",
|
||||
filter: filters,
|
||||
});
|
||||
|
||||
const [page, setPage] = useState(1);
|
||||
const onTimeRangeChange = useCallback((value: DateRange) => {
|
||||
setDateRange(value);
|
||||
|
||||
const onTimeRangeChange = useCallback(
|
||||
(value: DateRange) => {
|
||||
if (value) {
|
||||
setTimeRange((prevTimeRange) => ({
|
||||
...prevTimeRange,
|
||||
from: new Date(value.from),
|
||||
to: new Date(value.to),
|
||||
}));
|
||||
}
|
||||
},
|
||||
[setTimeRange]
|
||||
);
|
||||
if (value?.from && value?.to) {
|
||||
setFilters((prev) => ({
|
||||
...prev,
|
||||
start_time: value.from.toISOString(),
|
||||
end_time: value.to.toISOString(),
|
||||
}));
|
||||
} else {
|
||||
setFilters((prev) => {
|
||||
const newFilters = { ...prev };
|
||||
delete newFilters.start_time;
|
||||
delete newFilters.end_time;
|
||||
return newFilters;
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Error fetching query history"
|
||||
errorMsg={error?.message}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<CardSection className="mt-8">
|
||||
@@ -148,12 +180,22 @@ export function QueryHistoryTable() {
|
||||
<div className="flex">
|
||||
<div className="gap-y-3 flex flex-col">
|
||||
<SelectFeedbackType
|
||||
value={selectedFeedbackType || "all"}
|
||||
onValueChange={setSelectedFeedbackType}
|
||||
value={filters.feedback_type || "all"}
|
||||
onValueChange={(value) => {
|
||||
setFilters((prev) => {
|
||||
const newFilters = { ...prev };
|
||||
if (value === "all") {
|
||||
delete newFilters.feedback_type;
|
||||
} else {
|
||||
newFilters.feedback_type = value;
|
||||
}
|
||||
return newFilters;
|
||||
});
|
||||
}}
|
||||
/>
|
||||
|
||||
<DateRangeSelector
|
||||
value={timeRange}
|
||||
value={dateRange}
|
||||
onValueChange={onTimeRangeChange}
|
||||
/>
|
||||
</div>
|
||||
@@ -172,33 +214,33 @@ export function QueryHistoryTable() {
|
||||
<TableHead>Date</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{chatSessionData &&
|
||||
chatSessionData
|
||||
.slice(NUM_IN_PAGE * (page - 1), NUM_IN_PAGE * page)
|
||||
.map((chatSessionMinimal) => (
|
||||
<QueryHistoryTableRow
|
||||
key={chatSessionMinimal.id}
|
||||
chatSessionMinimal={chatSessionMinimal}
|
||||
/>
|
||||
))}
|
||||
</TableBody>
|
||||
{isLoading ? (
|
||||
<TableBody>
|
||||
<TableRow>
|
||||
<TableCell colSpan={6} className="text-center">
|
||||
<ThreeDotsLoader />
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
) : (
|
||||
<TableBody>
|
||||
{chatSessionData?.map((chatSessionMinimal) => (
|
||||
<QueryHistoryTableRow
|
||||
key={chatSessionMinimal.id}
|
||||
chatSessionMinimal={chatSessionMinimal}
|
||||
/>
|
||||
))}
|
||||
</TableBody>
|
||||
)}
|
||||
</Table>
|
||||
|
||||
{chatSessionData && (
|
||||
<div className="mt-3 flex">
|
||||
<div className="mx-auto">
|
||||
<PageSelector
|
||||
totalPages={Math.ceil(chatSessionData.length / NUM_IN_PAGE)}
|
||||
currentPage={page}
|
||||
onPageChange={(newPage) => {
|
||||
setPage(newPage);
|
||||
window.scrollTo({
|
||||
top: 0,
|
||||
left: 0,
|
||||
behavior: "smooth",
|
||||
});
|
||||
}}
|
||||
totalPages={totalPages}
|
||||
currentPage={currentPage}
|
||||
onPageChange={goToPage}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -106,9 +106,7 @@ export default function QueryPage(props: { params: Promise<{ id: string }> }) {
|
||||
|
||||
<div className="flex flex-col">
|
||||
{chatSessionSnapshot.messages.map((message) => {
|
||||
return (
|
||||
<MessageDisplay key={message.time_created} message={message} />
|
||||
);
|
||||
return <MessageDisplay key={message.id} message={message} />;
|
||||
})}
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
@@ -25,6 +25,7 @@ export interface AbridgedSearchDoc {
|
||||
}
|
||||
|
||||
export interface MessageSnapshot {
|
||||
id: number;
|
||||
message: string;
|
||||
message_type: "user" | "assistant";
|
||||
documents: AbridgedSearchDoc[];
|
||||
|
||||
@@ -1,23 +1,45 @@
|
||||
import { useState } from "react";
|
||||
import { FiCheck, FiCopy } from "react-icons/fi";
|
||||
import { Hoverable, HoverableIcon } from "./Hoverable";
|
||||
import { HoverableIcon } from "./Hoverable";
|
||||
import { CheckmarkIcon, CopyMessageIcon } from "./icons/icons";
|
||||
|
||||
export function CopyButton({
|
||||
content,
|
||||
onClick,
|
||||
}: {
|
||||
content?: string;
|
||||
content?: string | { html: string; plainText: string };
|
||||
onClick?: () => void;
|
||||
}) {
|
||||
const [copyClicked, setCopyClicked] = useState(false);
|
||||
|
||||
const copyToClipboard = async (
|
||||
content: string | { html: string; plainText: string }
|
||||
) => {
|
||||
try {
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[typeof content === "string" ? content : content.html],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob(
|
||||
[typeof content === "string" ? content : content.plainText],
|
||||
{ type: "text/plain" }
|
||||
),
|
||||
});
|
||||
await navigator.clipboard.write([clipboardItem]);
|
||||
} catch (err) {
|
||||
// Fallback to basic text copy if HTML copy fails
|
||||
await navigator.clipboard.writeText(
|
||||
typeof content === "string" ? content : content.plainText
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<HoverableIcon
|
||||
icon={copyClicked ? <CheckmarkIcon /> : <CopyMessageIcon />}
|
||||
onClick={() => {
|
||||
if (content) {
|
||||
navigator.clipboard.writeText(content.toString());
|
||||
copyToClipboard(content);
|
||||
}
|
||||
onClick && onClick();
|
||||
|
||||
|
||||
@@ -16,17 +16,19 @@ interface TextViewProps {
|
||||
presentingDocument: OnyxDocument;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function TextView({
|
||||
presentingDocument,
|
||||
onClose,
|
||||
}: TextViewProps) {
|
||||
const [zoom, setZoom] = useState(100);
|
||||
const [fileContent, setFileContent] = useState<string>("");
|
||||
const [fileUrl, setFileUrl] = useState<string>("");
|
||||
const [fileName, setFileName] = useState<string>("");
|
||||
const [fileContent, setFileContent] = useState("");
|
||||
const [fileUrl, setFileUrl] = useState("");
|
||||
const [fileName, setFileName] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [fileType, setFileType] = useState<string>("application/octet-stream");
|
||||
const [fileType, setFileType] = useState("application/octet-stream");
|
||||
|
||||
// Detect if a given MIME type is one of the recognized markdown formats
|
||||
const isMarkdownFormat = (mimeType: string): boolean => {
|
||||
const markdownFormats = [
|
||||
"text/markdown",
|
||||
@@ -38,6 +40,7 @@ export default function TextView({
|
||||
return markdownFormats.some((format) => mimeType.startsWith(format));
|
||||
};
|
||||
|
||||
// Detect if a given MIME type can be rendered in an <iframe>
|
||||
const isSupportedIframeFormat = (mimeType: string): boolean => {
|
||||
const supportedFormats = [
|
||||
"application/pdf",
|
||||
@@ -52,6 +55,7 @@ export default function TextView({
|
||||
const fetchFile = useCallback(async () => {
|
||||
setIsLoading(true);
|
||||
const fileId = presentingDocument.document_id.split("__")[1];
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/chat/file/${encodeURIComponent(fileId)}`,
|
||||
@@ -62,18 +66,33 @@ export default function TextView({
|
||||
const blob = await response.blob();
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
setFileUrl(url);
|
||||
setFileName(presentingDocument.semantic_identifier || "document");
|
||||
const contentType =
|
||||
|
||||
const originalFileName =
|
||||
presentingDocument.semantic_identifier || "document";
|
||||
setFileName(originalFileName);
|
||||
|
||||
let contentType =
|
||||
response.headers.get("Content-Type") || "application/octet-stream";
|
||||
|
||||
// If it's octet-stream but file name suggests a markdown extension, override and attempt to read as markdown
|
||||
if (
|
||||
contentType === "application/octet-stream" &&
|
||||
(originalFileName.toLowerCase().endsWith(".md") ||
|
||||
originalFileName.toLowerCase().endsWith(".markdown"))
|
||||
) {
|
||||
contentType = "text/markdown";
|
||||
}
|
||||
setFileType(contentType);
|
||||
|
||||
if (isMarkdownFormat(blob.type)) {
|
||||
// If the final content type looks like markdown, read its text
|
||||
if (isMarkdownFormat(contentType)) {
|
||||
const text = await blob.text();
|
||||
setFileContent(text);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching file:", error);
|
||||
} finally {
|
||||
// Keep the slight delay for a smoother loading experience
|
||||
setTimeout(() => {
|
||||
setIsLoading(false);
|
||||
}, 1000);
|
||||
@@ -97,11 +116,8 @@ export default function TextView({
|
||||
const handleZoomOut = () => setZoom((prev) => Math.max(prev - 25, 100));
|
||||
|
||||
return (
|
||||
<Dialog open={true} onOpenChange={onClose}>
|
||||
<DialogContent
|
||||
hideCloseIcon
|
||||
className="max-w-5xl w-[90vw] flex flex-col justify-between gap-y-0 h-full max-h-[80vh] p-0"
|
||||
>
|
||||
<Dialog open onOpenChange={onClose}>
|
||||
<DialogContent className="max-w-5xl w-[90vw] flex flex-col justify-between gap-y-0 h-full max-h-[80vh] p-0">
|
||||
<DialogHeader className="px-4 mb-0 pt-2 pb-3 flex flex-row items-center justify-between border-b">
|
||||
<DialogTitle className="text-lg font-medium truncate">
|
||||
{fileName}
|
||||
@@ -120,12 +136,13 @@ export default function TextView({
|
||||
<Download className="h-4 w-4" />
|
||||
<span className="sr-only">Download</span>
|
||||
</Button>
|
||||
<Button variant="ghost" size="icon" onClick={() => onClose()}>
|
||||
<Button variant="ghost" size="icon" onClick={onClose}>
|
||||
<XIcon className="h-4 w-4" />
|
||||
<span className="sr-only">Close</span>
|
||||
</Button>
|
||||
</div>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="mt-0 rounded-b-lg flex-1 overflow-hidden">
|
||||
<div className="flex items-center justify-center w-full h-full">
|
||||
{isLoading ? (
|
||||
@@ -137,7 +154,7 @@ export default function TextView({
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
className={`w-full h-full transform origin-center transition-transform duration-300 ease-in-out`}
|
||||
className="w-full h-full transform origin-center transition-transform duration-300 ease-in-out"
|
||||
style={{ transform: `scale(${zoom / 100})` }}
|
||||
>
|
||||
{isSupportedIframeFormat(fileType) ? (
|
||||
@@ -150,7 +167,7 @@ export default function TextView({
|
||||
<div className="w-full h-full p-6 overflow-y-scroll overflow-x-hidden">
|
||||
<MinimalMarkdown
|
||||
content={fileContent}
|
||||
className="w-full pb-4 h-full text-lg text-wrap break-words"
|
||||
className="w-full pb-4 h-full text-lg break-words"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
|
||||
@@ -5,12 +5,14 @@ import {
|
||||
AcceptedUserSnapshot,
|
||||
InvitedUserSnapshot,
|
||||
} from "@/lib/types";
|
||||
import { ChatSessionMinimal } from "@/app/ee/admin/performance/usage/types";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
type PaginatedType =
|
||||
| IndexAttemptSnapshot
|
||||
| AcceptedUserSnapshot
|
||||
| InvitedUserSnapshot;
|
||||
| InvitedUserSnapshot
|
||||
| ChatSessionMinimal;
|
||||
|
||||
interface PaginatedApiResponse<T extends PaginatedType> {
|
||||
items: T[];
|
||||
@@ -22,7 +24,7 @@ interface PaginationConfig {
|
||||
pagesPerBatch: number;
|
||||
endpoint: string;
|
||||
query?: string;
|
||||
filter?: Record<string, string | boolean | number | string[]>;
|
||||
filter?: Record<string, string | boolean | number | string[] | Date>;
|
||||
refreshIntervalInMs?: number;
|
||||
}
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ export type ValidStatuses =
|
||||
| "in_progress"
|
||||
| "not_started";
|
||||
export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE";
|
||||
export type Feedback = "like" | "dislike";
|
||||
export type Feedback = "like" | "dislike" | "mixed";
|
||||
export type AccessType = "public" | "private" | "sync";
|
||||
export type SessionType = "Chat" | "Search" | "Slack";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user