Compare commits

..

13 Commits

Author SHA1 Message Date
pablodanswer
25b38212e9 nit 2025-01-19 09:50:35 -08:00
pablodanswer
3096b0b2a7 add linear check 2025-01-19 09:49:26 -08:00
Chris Weaver
342bb9f685 Fix document counts (#3671)
* Various fixes/improvements to document counting

* Add new column + index

* Avoid double scan

* comment fixes

* Fix revision history

* Fix IT

* Fix IT

* Fix migration

* Rebase
2025-01-19 05:36:07 +00:00
hagen-danswer
b25668c83a fixed group sync to account for changes in drive permissions (#3666)
* fixed group sync to account for changes in drive permissions

* mypy

* addressed

* reeeeeeeee
2025-01-19 00:08:50 +00:00
Weves
a72bd31f5d Small background telemetry fix 2025-01-18 16:19:28 -08:00
hagen-danswer
896e716d02 query history pagination tests (#3700)
* dummy pr

* Update prompts.yaml

* fixed tests and added query history pagination test

* done

* fixed

* utils!
2025-01-18 21:28:03 +00:00
pablonyx
eec3ce8162 Markdown rendering (#3698)
* nit

* update comment
2025-01-18 12:12:19 -08:00
pablonyx
2761a837c6 quick nit for no-longer living files (#3702) 2025-01-18 11:09:34 -08:00
hagen-danswer
da43abe644 Made copy button and cmd+c work for cmd+v and cmd+shift+v (#3693)
* Made copy button and cmd+c work for cmd+v and cmd+shift+v

* made sub selections work as well

* ok it works

* fixed npm run build

* im not from earth

* added logging

* more logging

* bye logs

* should work now

* whoops

* added stuff

* made it robust

* ctrl shift v behavior
2025-01-18 10:34:32 -08:00
skylares
af953ff8a3 Paginate Query History table (#3592)
* Add pagination for query history table

* Fix method name

* Fix mypy
2025-01-17 15:31:42 -08:00
rkuo-danswer
6fc52c81ab Bugfix/beat redux (#3639)
* WIP

* WIP

* try spinning out check for indexing into a system task

* check for the correct delimiter

* use constants

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-01-17 20:59:43 +00:00
hagen-danswer
1ad2128b2a Combined Persona and Prompt API (#3690)
* Combined Persona and Prompt API

* quality

* added tests

* consolidated models and got rid of redundant fields

* tenant appreciation day

* reverted default
2025-01-17 20:21:20 +00:00
Kaveen Jayamanna
880c42ad41 Validating slackbot tokens (#3695)
* added missing dependency, missing api key placeholder, updated docs

* Apply black formatting and validate bot token functionality

* acknowledging black formatting

* added the validation to update tokens as well

* Made the token validation errors looks nicer

* getting rif of duplicate dependency
2025-01-17 11:50:22 -08:00
89 changed files with 6425 additions and 2293 deletions

View File

@@ -1,11 +1,15 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] I have included a link to a Linear ticket in my description.
- [ ] [Optional] Override Linear Check

29
.github/workflows/pr-linear-check.yml vendored Normal file
View File

@@ -0,0 +1,29 @@
name: Ensure PR references Linear
on:
  pull_request:
    types: [opened, edited, reopened, synchronize]

jobs:
  linear-check:
    runs-on: ubuntu-latest
    steps:
      - name: Check PR body for Linear link or override
        # SECURITY: pass the PR body through an environment variable instead of
        # interpolating `${{ ... }}` directly into the script. Direct
        # interpolation lets a crafted PR description inject arbitrary shell
        # commands into this runner.
        env:
          PR_BODY: ${{ github.event.pull_request.body }}
        run: |
          # Looking for "https://linear.app" in the body
          if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
            echo "Found a Linear link. Check passed."
            exit 0
          fi
          # Looking for a checked override: "[x] Override Linear Check"
          if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
            echo "Override box is checked. Check passed."
            exit 0
          fi
          # Otherwise, fail the run
          echo "No Linear link or override found in the PR description."
          exit 1

View File

@@ -29,6 +29,7 @@ REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
OPENAI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o

View File

@@ -17,9 +17,10 @@ Before starting, make sure the Docker Daemon is running.
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.
4. Change directory into `web`, then run "npm i" followed by "npm run dev".
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
7. Use the debug toolbar to step through code, inspect variables, etc.
## Features

View File

@@ -0,0 +1,48 @@
"""Add has_been_indexed to DocumentByConnectorCredentialPair
Revision ID: c7bf5721733e
Revises: fec3db967bf7
Create Date: 2025-01-13 12:39:05.831693
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c7bf5721733e"
down_revision = "027381bce97c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# assume all existing rows have been indexed, no better approach
op.add_column(
"document_by_connector_credential_pair",
sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
)
op.alter_column(
"document_by_connector_credential_pair",
"has_been_indexed",
nullable=False,
)
# Add index to optimize get_document_counts_for_cc_pairs query pattern
op.create_index(
"idx_document_cc_pair_counts",
"document_by_connector_credential_pair",
["connector_id", "credential_id", "has_been_indexed"],
unique=False,
)
def downgrade() -> None:
# Remove the index first before removing the column
op.drop_index(
"idx_document_cc_pair_counts",
table_name="document_by_connector_credential_pair",
)
op.drop_column("document_by_connector_credential_pair", "has_been_indexed")

View File

@@ -1,6 +1,9 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule,
)
@@ -8,7 +11,7 @@ from onyx.configs.constants import OnyxCeleryTask
ee_tasks_to_schedule = [
{
"name": "autogenerate_usage_report",
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
},
@@ -20,5 +23,9 @@ ee_tasks_to_schedule = [
]
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return base_cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_tasks_to_schedule

View File

@@ -1,27 +1,135 @@
import datetime
from typing import Literal
from collections.abc import Sequence
from datetime import datetime
from sqlalchemy import asc
from sqlalchemy import BinaryExpression
from sqlalchemy import ColumnElement
from sqlalchemy import desc
from sqlalchemy import distinct
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import case
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression
from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
SortByOptions = Literal["time_sent"]
def _build_filter_conditions(
    start_time: datetime | None,
    end_time: datetime | None,
    feedback_filter: QAFeedbackType | None,
) -> list[ColumnElement]:
    """
    Helper function to build all filter conditions for chat sessions.
    Filters by start and end time and by aggregated feedback type.

    start_time: earliest ChatSession.time_created to include (inclusive)
    end_time: latest ChatSession.time_created to include (inclusive)
    feedback_filter: if set, restrict to sessions whose combined message
        feedback matches: LIKE = every feedback positive, DISLIKE = every
        feedback negative, anything else = "mixed" (at least one of each)

    Returns: List of filter conditions
    """
    conditions: list[ColumnElement] = []
    if start_time is not None:
        conditions.append(ChatSession.time_created >= start_time)
    if end_time is not None:
        conditions.append(ChatSession.time_created <= end_time)

    if feedback_filter is not None:
        # feedback_filter is a plain Python value known at query-build time,
        # so branch here rather than encoding the choice into a SQL CASE over
        # literal() booleans (as the original did). The emitted HAVING clause
        # is equivalent but much simpler.
        if feedback_filter == QAFeedbackType.LIKE:
            having_clause = func.bool_and(ChatMessageFeedback.is_positive)
        elif feedback_filter == QAFeedbackType.DISLIKE:
            having_clause = func.bool_and(func.not_(ChatMessageFeedback.is_positive))
        else:
            # "mixed": at least one positive AND at least one negative
            having_clause = func.bool_or(
                ChatMessageFeedback.is_positive
            ) & func.bool_or(func.not_(ChatMessageFeedback.is_positive))

        feedback_subq = (
            select(ChatMessage.chat_session_id)
            .join(ChatMessageFeedback)
            .group_by(ChatMessage.chat_session_id)
            .having(having_clause)
        )
        conditions.append(ChatSession.id.in_(feedback_subq))

    return conditions
def get_total_filtered_chat_sessions_count(
    db_session: Session,
    start_time: datetime | None,
    end_time: datetime | None,
    feedback_filter: QAFeedbackType | None,
) -> int:
    """Count the distinct chat sessions matching the time/feedback filters."""
    filter_conditions = _build_filter_conditions(
        start_time, end_time, feedback_filter
    )
    count_stmt = (
        select(func.count(distinct(ChatSession.id)))
        .select_from(ChatSession)
        .filter(*filter_conditions)
    )
    total = db_session.scalar(count_stmt)
    return total or 0
def get_page_of_chat_sessions(
    start_time: datetime | None,
    end_time: datetime | None,
    db_session: Session,
    page_num: int,
    page_size: int,
    feedback_filter: QAFeedbackType | None = None,
) -> Sequence[ChatSession]:
    """Fetch one page of chat sessions with messages and feedback eagerly loaded.

    start_time/end_time: optional inclusive bounds on ChatSession.time_created
    page_num: zero-based page index
    page_size: number of sessions per page
    feedback_filter: optional aggregated-feedback filter (see
        _build_filter_conditions)

    Returns: the sessions for the requested page
    """
    conditions = _build_filter_conditions(start_time, end_time, feedback_filter)

    # Subquery selects just the page of session ids via DISTINCT ON (id).
    # NOTE(review): DISTINCT ON requires ordering by id first, so pages are
    # effectively id-ordered rather than time-ordered — confirm this is the
    # intended pagination order before relying on it.
    subquery = (
        select(ChatSession.id, ChatSession.time_created)
        .filter(*conditions)
        .order_by(ChatSession.id, desc(ChatSession.time_created))
        .distinct(ChatSession.id)
        .limit(page_size)
        .offset(page_num * page_size)
        .subquery()
    )

    # Re-select the full sessions for the chosen ids. contains_eager populates
    # ChatSession.messages from the explicit outer join in this same query,
    # while joinedload handles user/persona/feedback relationships.
    stmt = (
        select(ChatSession)
        .join(subquery, ChatSession.id == subquery.c.id)
        .outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
        .options(
            joinedload(ChatSession.user),
            joinedload(ChatSession.persona),
            contains_eager(ChatSession.messages).joinedload(
                ChatMessage.chat_message_feedbacks
            ),
        )
        .order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
    )

    # unique() is required when eager-loading collections through joins
    return db_session.scalars(stmt).unique().all()
def fetch_chat_sessions_eagerly_by_time(
start: datetime.datetime,
end: datetime.datetime,
start: datetime,
end: datetime,
db_session: Session,
limit: int | None = 500,
initial_time: datetime.datetime | None = None,
initial_time: datetime | None = None,
) -> list[ChatSession]:
time_order: UnaryExpression = desc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)

View File

@@ -120,9 +120,12 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_emails,
external_user_group_ids=group_ids,
is_public=public,
)

View File

@@ -1,16 +1,127 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import AdminService
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_drive_members(
    google_drive_connector: GoogleDriveConnector,
) -> dict[str, tuple[set[str], set[str]]]:
    """
    Build a map of drive ids to their members (group and user emails).
    E.g. {
        "drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}),
        "drive_id_2": ({"group_email_3"}, {"user_email_3"}),
    }
    """
    drive_service = get_drive_service(
        google_drive_connector.creds,
        google_drive_connector.primary_admin_email,
    )

    members_by_drive: dict[str, tuple[set[str], set[str]]] = {}
    for current_drive_id in google_drive_connector.get_all_drive_ids():
        drive_group_emails: set[str] = set()
        drive_user_emails: set[str] = set()
        for perm in execute_paginated_retrieval(
            drive_service.permissions().list,
            list_key="permissions",
            fileId=current_drive_id,
            fields="permissions(emailAddress, type)",
            supportsAllDrives=True,
        ):
            if perm["type"] == "group":
                drive_group_emails.add(perm["emailAddress"])
            elif perm["type"] == "user":
                drive_user_emails.add(perm["emailAddress"])
        members_by_drive[current_drive_id] = (drive_group_emails, drive_user_emails)

    return members_by_drive
def _get_all_groups(
    admin_service: AdminService,
    google_domain: str,
) -> set[str]:
    """Return the email address of every group in the Google domain."""
    return {
        group["email"]
        for group in execute_paginated_retrieval(
            admin_service.groups().list,
            list_key="groups",
            domain=google_domain,
            fields="groups(email)",
        )
    }
def _map_group_email_to_member_emails(
    admin_service: AdminService,
    group_emails: set[str],
) -> dict[str, set[str]]:
    """Map each group email to the set of its members' email addresses."""
    member_map: dict[str, set[str]] = {}
    for current_group in group_emails:
        member_map[current_group] = {
            member["email"]
            for member in execute_paginated_retrieval(
                admin_service.members().list,
                list_key="members",
                groupKey=current_group,
                fields="members(email)",
            )
        }
    return member_map
def _build_onyx_groups(
    drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
    group_email_to_member_emails_map: dict[str, set[str]],
) -> list[ExternalUserGroup]:
    """Convert drive- and group-membership maps into ExternalUserGroups.

    drive_id_to_members_map: drive id -> (group emails, user emails)
    group_email_to_member_emails_map: group email -> member emails

    Returns one ExternalUserGroup per drive (drive-level access means
    irrevocable access to all files in the drive, so drive members form a
    group) plus one ExternalUserGroup per Google group.
    """
    onyx_groups: list[ExternalUserGroup] = []

    # Convert all drive member definitions to onyx groups
    for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
        # BUGFIX: copy instead of aliasing user_emails — the original bound
        # all_member_emails directly to the set inside drive_id_to_members_map
        # and then mutated it via update(), corrupting the caller's map.
        all_member_emails: set[str] = set(user_emails)
        for group_email in group_emails:
            # A drive can be shared with a group that is not in the domain's
            # group map; .get() with an empty default avoids a KeyError there.
            all_member_emails.update(
                group_email_to_member_emails_map.get(group_email, set())
            )
        onyx_groups.append(
            ExternalUserGroup(
                id=drive_id,
                user_emails=list(all_member_emails),
            )
        )

    # Convert all group member definitions to onyx groups
    for group_email, member_emails in group_email_to_member_emails_map.items():
        onyx_groups.append(
            ExternalUserGroup(
                id=group_email,
                user_emails=list(member_emails),
            )
        )

    return onyx_groups
def gdrive_group_sync(
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
@@ -19,34 +130,23 @@ def gdrive_group_sync(
google_drive_connector.creds, google_drive_connector.primary_admin_email
)
onyx_groups: list[ExternalUserGroup] = []
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_drive_connector.google_domain,
fields="groups(email)",
):
# The id is the group email
group_email = group["email"]
# Get all drive members
drive_id_to_members_map = _get_drive_members(google_drive_connector)
# Gather group member emails
group_member_emails: list[str] = []
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.append(member["email"])
# Get all group emails
all_group_emails = _get_all_groups(
admin_service, google_drive_connector.google_domain
)
if not group_member_emails:
continue
# Map group emails to their members
group_email_to_member_emails_map = _map_group_email_to_member_emails(
admin_service, all_group_emails
)
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(group_member_emails),
)
)
# Convert the maps to onyx groups
onyx_groups = _build_onyx_groups(
drive_id_to_members_map=drive_id_to_members_map,
group_email_to_member_emails_map=group_email_to_member_emails_map,
)
return onyx_groups

View File

@@ -1,19 +1,23 @@
import csv
import io
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Literal
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import MessageSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
@@ -23,257 +27,15 @@ from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine import get_session
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
class AbridgedSearchDoc(BaseModel):
"""A subset of the info present in `SearchDoc`"""
document_id: str
semantic_identifier: str
link: str | None
class MessageSnapshot(BaseModel):
message: str
message_type: MessageType
documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
time_created: datetime
@classmethod
def build(cls, message: ChatMessage) -> "MessageSnapshot":
latest_messages_feedback_obj = (
message.chat_message_feedbacks[-1]
if len(message.chat_message_feedbacks) > 0
else None
)
feedback_type = (
(
QAFeedbackType.LIKE
if latest_messages_feedback_obj.is_positive
else QAFeedbackType.DISLIKE
)
if latest_messages_feedback_obj
else None
)
feedback_text = (
latest_messages_feedback_obj.feedback_text
if latest_messages_feedback_obj
else None
)
return cls(
message=message.message,
message_type=message.message_type,
documents=[
AbridgedSearchDoc(
document_id=document.document_id,
semantic_identifier=document.semantic_id,
link=document.link,
)
for document in message.search_docs
],
feedback_type=feedback_type,
feedback_text=feedback_text,
time_created=message.time_sent,
)
class ChatSessionMinimal(BaseModel):
id: UUID
user_email: str
name: str | None
first_user_message: str
first_ai_message: str
assistant_id: int | None
assistant_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | Literal["mixed"] | None
flow_type: SessionType
conversation_length: int
class ChatSessionSnapshot(BaseModel):
id: UUID
user_email: str
name: str | None
messages: list[MessageSnapshot]
assistant_id: int | None
assistant_name: str | None
time_created: datetime
flow_type: SessionType
class QuestionAnswerPairSnapshot(BaseModel):
chat_session_id: UUID
# 1-indexed message number in the chat_session
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
message_pair_num: int
user_message: str
ai_response: str
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str | None
user_email: str
time_created: datetime
flow_type: SessionType
@classmethod
def from_chat_session_snapshot(
cls,
chat_session_snapshot: ChatSessionSnapshot,
) -> list["QuestionAnswerPairSnapshot"]:
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
for ind in range(1, len(chat_session_snapshot.messages), 2):
message_pairs.append(
(
chat_session_snapshot.messages[ind - 1],
chat_session_snapshot.messages[ind],
)
)
return [
cls(
chat_session_id=chat_session_snapshot.id,
message_pair_num=ind + 1,
user_message=user_message.message,
ai_response=ai_message.message,
retrieved_documents=ai_message.documents,
feedback_type=ai_message.feedback_type,
feedback_text=ai_message.feedback_text,
persona_name=chat_session_snapshot.assistant_name,
user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created,
flow_type=chat_session_snapshot.flow_type,
)
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
"user_message": self.user_message,
"ai_response": self.ai_response,
"retrieved_documents": "|".join(
[
doc.link or doc.semantic_identifier
for doc in self.retrieved_documents
]
),
"feedback_type": self.feedback_type.value if self.feedback_type else "",
"feedback_text": self.feedback_text or "",
"persona_name": self.persona_name,
"user_email": self.user_email,
"time_created": str(self.time_created),
"flow_type": self.flow_type,
}
def determine_flow_type(chat_session: ChatSession) -> SessionType:
return SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
def fetch_and_process_chat_session_history_minimal(
db_session: Session,
start: datetime,
end: datetime,
feedback_filter: QAFeedbackType | None = None,
limit: int | None = 500,
) -> list[ChatSessionMinimal]:
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
minimal_sessions = []
for chat_session in chat_sessions:
if not chat_session.messages:
continue
first_user_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.USER
),
"",
)
first_ai_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.ASSISTANT
),
"",
)
has_positive_feedback = any(
feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
)
has_negative_feedback = any(
not feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
)
feedback_type: QAFeedbackType | Literal["mixed"] | None = (
"mixed"
if has_positive_feedback and has_negative_feedback
else QAFeedbackType.LIKE
if has_positive_feedback
else QAFeedbackType.DISLIKE
if has_negative_feedback
else None
)
if feedback_filter:
if feedback_filter == QAFeedbackType.LIKE and not has_positive_feedback:
continue
if feedback_filter == QAFeedbackType.DISLIKE and not has_negative_feedback:
continue
flow_type = determine_flow_type(chat_session)
minimal_sessions.append(
ChatSessionMinimal(
id=chat_session.id,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
assistant_id=chat_session.persona_id,
assistant_name=(
chat_session.persona.name if chat_session.persona else None
),
time_created=chat_session.time_created,
feedback_type=feedback_type,
flow_type=flow_type,
conversation_length=len(
[
m
for m in chat_session.messages
if m.message_type != MessageType.SYSTEM
]
),
)
)
return minimal_sessions
def fetch_and_process_chat_session_history(
db_session: Session,
start: datetime,
@@ -319,7 +81,7 @@ def snapshot_from_chat_session(
except RuntimeError:
return None
flow_type = determine_flow_type(chat_session)
flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
return ChatSessionSnapshot(
id=chat_session.id,
@@ -371,22 +133,38 @@ def get_user_chat_sessions(
@router.get("/admin/chat-session-history")
def get_chat_session_history(
page_num: int = Query(0, ge=0),
page_size: int = Query(10, ge=1),
feedback_type: QAFeedbackType | None = None,
start: datetime | None = None,
end: datetime | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[ChatSessionMinimal]:
return fetch_and_process_chat_session_history_minimal(
) -> PaginatedReturn[ChatSessionMinimal]:
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
db_session=db_session,
start=start
or (
datetime.now(tz=timezone.utc) - timedelta(days=30)
), # default is 30d lookback
end=end or datetime.now(tz=timezone.utc),
start_time=start_time,
end_time=end_time,
feedback_filter=feedback_type,
)
total_filtered_chat_sessions_count = get_total_filtered_chat_sessions_count(
db_session=db_session,
start_time=start_time,
end_time=end_time,
feedback_filter=feedback_type,
)
return PaginatedReturn(
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
total_items=total_filtered_chat_sessions_count,
)
@router.get("/admin/chat-session-history/{chat_session_id}")
def get_chat_session_admin(

View File

@@ -0,0 +1,218 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from onyx.auth.users import get_display_email
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import SessionType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
class AbridgedSearchDoc(BaseModel):
    """A subset of the info present in `SearchDoc`"""

    document_id: str
    semantic_identifier: str
    # may be None for documents without a resolvable URL
    link: str | None
class MessageSnapshot(BaseModel):
    """Flattened view of a single ChatMessage and its latest feedback."""

    id: int
    message: str
    message_type: MessageType
    documents: list[AbridgedSearchDoc]
    feedback_type: QAFeedbackType | None
    feedback_text: str | None
    time_created: datetime

    @classmethod
    def build(cls, message: ChatMessage) -> "MessageSnapshot":
        # Only the most recent feedback entry is surfaced in the snapshot.
        latest_feedback = (
            message.chat_message_feedbacks[-1]
            if message.chat_message_feedbacks
            else None
        )

        feedback_type: QAFeedbackType | None = None
        feedback_text: str | None = None
        if latest_feedback is not None:
            feedback_type = (
                QAFeedbackType.LIKE
                if latest_feedback.is_positive
                else QAFeedbackType.DISLIKE
            )
            feedback_text = latest_feedback.feedback_text

        abridged_docs = [
            AbridgedSearchDoc(
                document_id=doc.document_id,
                semantic_identifier=doc.semantic_id,
                link=doc.link,
            )
            for doc in message.search_docs
        ]

        return cls(
            id=message.id,
            message=message.message,
            message_type=message.message_type,
            documents=abridged_docs,
            feedback_type=feedback_type,
            feedback_text=feedback_text,
            time_created=message.time_sent,
        )
class ChatSessionMinimal(BaseModel):
    """Summary row for a chat session, as shown in the admin history table."""

    id: UUID
    user_email: str
    name: str | None
    first_user_message: str
    first_ai_message: str
    assistant_id: int | None
    assistant_name: str | None
    time_created: datetime
    feedback_type: QAFeedbackType | None
    flow_type: SessionType
    conversation_length: int

    @classmethod
    def from_chat_session(cls, chat_session: ChatSession) -> "ChatSessionMinimal":
        def _first_message_of_type(wanted_type: MessageType) -> str:
            # Earliest message of the given type, or "" if none exists.
            for msg in chat_session.messages:
                if msg.message_type == wanted_type:
                    return msg.message
            return ""

        # Aggregate per-message feedback into one session-level verdict.
        feedback_values = [
            feedback.is_positive
            for msg in chat_session.messages
            for feedback in msg.chat_message_feedbacks
        ]
        if not feedback_values:
            session_feedback_type = None
        elif all(feedback_values):
            session_feedback_type = QAFeedbackType.LIKE
        elif not any(feedback_values):
            session_feedback_type = QAFeedbackType.DISLIKE
        else:
            session_feedback_type = QAFeedbackType.MIXED

        non_system_message_count = sum(
            1
            for msg in chat_session.messages
            if msg.message_type != MessageType.SYSTEM
        )

        return cls(
            id=chat_session.id,
            user_email=get_display_email(
                chat_session.user.email if chat_session.user else None
            ),
            name=chat_session.description,
            first_user_message=_first_message_of_type(MessageType.USER),
            first_ai_message=_first_message_of_type(MessageType.ASSISTANT),
            assistant_id=chat_session.persona_id,
            assistant_name=(
                chat_session.persona.name if chat_session.persona else None
            ),
            time_created=chat_session.time_created,
            feedback_type=session_feedback_type,
            flow_type=(
                SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
            ),
            conversation_length=non_system_message_count,
        )
class ChatSessionSnapshot(BaseModel):
    # Full export view of a chat session: metadata plus every message snapshot.
    id: UUID
    user_email: str
    name: str | None
    messages: list[MessageSnapshot]
    assistant_id: int | None
    assistant_name: str | None
    time_created: datetime
    flow_type: SessionType
class QuestionAnswerPairSnapshot(BaseModel):
    """One user/AI message pair from a chat session, flattened for export."""

    chat_session_id: UUID
    # 1-indexed message number in the chat_session
    # e.g. the first message pair in the chat_session is 1, the second is 2, etc.
    message_pair_num: int
    user_message: str
    ai_response: str
    retrieved_documents: list[AbridgedSearchDoc]
    feedback_type: QAFeedbackType | None
    feedback_text: str | None
    persona_name: str | None
    user_email: str
    time_created: datetime
    flow_type: SessionType

    @classmethod
    def from_chat_session_snapshot(
        cls,
        chat_session_snapshot: ChatSessionSnapshot,
    ) -> list["QuestionAnswerPairSnapshot"]:
        # Messages are paired in order (assumed to alternate user/AI).
        # An odd trailing message is silently dropped by the range step of 2.
        message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
        for ind in range(1, len(chat_session_snapshot.messages), 2):
            message_pairs.append(
                (
                    chat_session_snapshot.messages[ind - 1],
                    chat_session_snapshot.messages[ind],
                )
            )

        return [
            cls(
                chat_session_id=chat_session_snapshot.id,
                message_pair_num=ind + 1,
                user_message=user_message.message,
                ai_response=ai_message.message,
                retrieved_documents=ai_message.documents,
                feedback_type=ai_message.feedback_type,
                feedback_text=ai_message.feedback_text,
                persona_name=chat_session_snapshot.assistant_name,
                user_email=get_display_email(chat_session_snapshot.user_email),
                time_created=user_message.time_created,
                flow_type=chat_session_snapshot.flow_type,
            )
            for ind, (user_message, ai_message) in enumerate(message_pairs)
        ]

    def to_json(self) -> dict[str, str | None]:
        # Flatten to strings for CSV export; document links joined with "|".
        return {
            "chat_session_id": str(self.chat_session_id),
            "message_pair_num": str(self.message_pair_num),
            "user_message": self.user_message,
            "ai_response": self.ai_response,
            "retrieved_documents": "|".join(
                [
                    doc.link or doc.semantic_identifier
                    for doc in self.retrieved_documents
                ]
            ),
            "feedback_type": self.feedback_type.value if self.feedback_type else "",
            "feedback_text": self.feedback_text or "",
            "persona_name": self.persona_name,
            "user_email": self.user_email,
            "time_created": str(self.time_created),
            # NOTE(review): unlike feedback_type above, .value is not taken
            # here, so this entry is a SessionType, not a str — confirm
            # SessionType is a str-enum or this violates the declared
            # dict[str, str | None] return type.
            "flow_type": self.flow_type,
        }

View File

@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.settings.models import Settings
from onyx.server.settings.store import store_settings as store_base_settings
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
llms: list[LLMProviderUpsertRequest] | None = None
admin_user_emails: list[str] | None = None
seeded_logo_path: str | None = None
personas: list[CreatePersonaRequest] | None = None
personas: list[PersonaUpsertRequest] | None = None
settings: Settings | None = None
enterprise_settings: EnterpriseSettings | None = None
@@ -128,7 +128,7 @@ def _seed_llms(
)
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
if personas:
logger.notice("Seeding Personas")
for persona in personas:

View File

@@ -20,6 +20,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
@@ -100,6 +101,10 @@ def on_task_postrun(
if not task_id:
return
if task.name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# this is a cloud / all tenant task ... no postrun is needed
return
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")

View File

@@ -1,5 +1,6 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -7,12 +8,14 @@ from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -28,7 +31,7 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._update_tenant_tasks()
self._try_updating_schedule()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
@@ -44,105 +47,154 @@ class DynamicTenantScheduler(PersistentScheduler):
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating task update")
self._update_tenant_tasks()
try:
self._try_updating_schedule()
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
self._last_reload = now
logger.info("Task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting task update process")
try:
logger.info("Fetching all IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
def _generate_schedule(
self, tenant_ids: list[str] | list[None]
) -> dict[str, dict[str, Any]]:
"""Given a list of tenant id's, generates a new beat schedule for celery."""
logger.info("Fetching tasks to schedule")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
new_schedule: dict[str, dict[str, Any]] = {}
if MULTI_TENANT:
# cloud tasks only need the single task beat across all tenants
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule",
"get_cloud_tasks_to_schedule",
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
cloud_tasks_to_schedule: list[
dict[str, Any]
] = get_cloud_tasks_to_schedule()
for task in cloud_tasks_to_schedule:
task_name = task["name"]
cloud_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
cloud_task["options"] = options
new_schedule[task_name] = cloud_task
current_schedule = self.schedule.items()
# regular task beats are multiplied across all tenants
get_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
tasks_to_schedule: list[dict[str, Any]] = get_tasks_to_schedule()
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new item: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
for tenant_id in tenant_ids:
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
for task in tasks_to_schedule:
task_name = task["name"]
tenant_task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {tenant_task_name}")
tenant_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
new_schedule[tenant_task_name] = tenant_task
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
return new_schedule
# Ensure changes are persisted
self.sync()
def _try_updating_schedule(self) -> None:
"""Only updates the actual beat schedule on the celery app when it changes"""
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
logger.info("_try_updating_schedule starting")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
# get current schedule and extract current tenants
current_schedule = self.schedule.items()
current_tenants = set()
for task_name, _ in current_schedule:
task_name = cast(str, task_name)
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
continue
if "_" in task_name:
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# -> "12345678-abcd-efgh-ijkl-12345678"
current_tenants.add(task_name.split("_")[-1])
logger.info(f"Found {len(current_tenants)} existing items in schedule")
for tenant_id in tenant_ids:
if tenant_id not in current_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
new_schedule = self._generate_schedule(tenant_ids)
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
logger.info(
"_try_updating_schedule: Current schedule is up to date, no changes needed"
)
return
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("_try_updating_schedule: Schedule updated successfully")
@staticmethod
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
"""Compare schedules to determine if an update is needed.
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
if current_tasks != new_tasks:
return False
return True
@beat_init.connect

View File

@@ -17,7 +17,7 @@ from redis.lock import Lock as RedisLock
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.indexing.tasks import (
from onyx.background.celery.tasks.indexing.utils import (
get_unfenced_index_attempt_ids,
)
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT

View File

@@ -2,25 +2,43 @@ from datetime import timedelta
from typing import Any
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
# we might be able to reduce this greatly if we can run a unified
# loop across all tenants rather than tasks per tenant
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that run in either self-hosted or cloud deployments
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -29,16 +47,7 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -47,7 +56,7 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -65,7 +74,7 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -84,7 +93,7 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -93,12 +102,25 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
if not MULTI_TENANT:
tasks_to_schedule.append(
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
@@ -114,5 +136,9 @@ if LLM_MODEL_UPDATE_API_URL:
)
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
    # Beat entries that run once across all tenants (cloud / multi-tenant only).
    return cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:
    # Beat entries that are multiplied per tenant by the scheduler.
    return tasks_to_schedule

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -44,7 +44,7 @@ def check_for_connector_deletion_task(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap

View File

@@ -22,9 +22,9 @@ from ee.onyx.external_permissions.sync_params import (
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
@@ -99,7 +99,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap

View File

@@ -22,7 +22,7 @@ from ee.onyx.external_permissions.sync_params import (
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -99,7 +99,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,519 @@
import time
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
    """Gets a list of unfenced index attempts. Should not be possible, so we'd
    typically want to clean them up.

    Unfenced = attempt not in terminal state and fence does not exist.

    Uses an inner/outer/inner double-check to avoid racing an attempt that
    completes (and legitimately drops its fence) between our two reads:
      inner = index_attempt in a non-terminal state
      outer = the redis fence key is down
    """
    unfenced_attempts: list[int] = []

    # gather every attempt the db considers non-terminal
    candidates: list[IndexAttempt] = [
        *get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session),
        *get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session),
    ]

    for candidate in candidates:
        fence_key = RedisConnectorIndex.fence_key_with_ids(
            candidate.connector_credential_pair_id, candidate.search_settings_id
        )

        # fence present -> attempt is being handled normally, nothing to do
        if r.exists(fence_key):
            continue

        # Fence is down: the attempt may have just completed and taken the
        # fence down normally. Re-read the attempt and only flag it when it is
        # still in the same non-terminal state, confirming a genuinely bad
        # state.
        rechecked = get_index_attempt(db_session, candidate.id)
        if rechecked is None:
            continue
        if rechecked.status != candidate.status:
            continue

        unfenced_attempts.append(candidate.id)

    return unfenced_attempts
class IndexingCallback(IndexingHeartbeatInterface):
    """Heartbeat callback handed to a long-running indexing job.

    Responsibilities (as exercised by the code below):
    - should_stop: reports whether an external stop signal key exists in redis.
    - progress: periodically reacquires the supplied redis lock so it does not
      expire mid-run, and accumulates progress counts in redis via INCRBY.
    """

    # seconds between liveness checks of the parent process; currently unused —
    # see the commented-out block in progress()
    PARENT_CHECK_INTERVAL = 60

    def __init__(
        self,
        parent_pid: int,
        stop_key: str,
        generator_progress_key: str,
        redis_lock: RedisLock,
        redis_client: Redis,
    ):
        super().__init__()
        self.parent_pid = parent_pid
        self.redis_lock: RedisLock = redis_lock
        self.stop_key: str = stop_key
        self.generator_progress_key: str = generator_progress_key
        self.redis_client = redis_client
        self.started: datetime = datetime.now(timezone.utc)
        # refresh the lock immediately so its TTL starts from construction time
        self.redis_lock.reacquire()

        # bookkeeping used for diagnostics if a later reacquire fails
        self.last_tag: str = "IndexingCallback.__init__"
        self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
        # monotonic clock avoids wall-clock jumps affecting the reacquire cadence
        self.last_lock_monotonic = time.monotonic()

        self.last_parent_check = time.monotonic()

    def should_stop(self) -> bool:
        # the presence of the stop key in redis is the external abort signal
        if self.redis_client.exists(self.stop_key):
            return True
        return False

    def progress(self, tag: str, amount: int) -> None:
        # rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
        # with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
        # so leave this code in until we're ready to test it.
        # if self.parent_pid:
        #     # check if the parent pid is alive so we aren't running as a zombie
        #     now = time.monotonic()
        #     if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
        #         try:
        #             # this is unintuitive, but it checks if the parent pid is still running
        #             os.kill(self.parent_pid, 0)
        #         except Exception:
        #             logger.exception("IndexingCallback - parent pid check exceptioned")
        #             raise
        #         self.last_parent_check = now

        try:
            current_time = time.monotonic()
            # reacquire well before the lock can time out (every 1/4 of the
            # timeout) rather than on every single progress call
            if current_time - self.last_lock_monotonic >= (
                CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
            ):
                self.redis_lock.reacquire()
                self.last_lock_reacquire = datetime.now(timezone.utc)
                self.last_lock_monotonic = time.monotonic()

            self.last_tag = tag
        except LockError:
            # lock expired or was lost — dump diagnostic state, then propagate
            # so the indexing job aborts rather than running unprotected
            logger.exception(
                f"IndexingCallback - lock.reacquire exceptioned: "
                f"lock_timeout={self.redis_lock.timeout} "
                f"start={self.started} "
                f"last_tag={self.last_tag} "
                f"last_reacquired={self.last_lock_reacquire} "
                f"now={datetime.now(timezone.utc)}"
            )

            redis_lock_dump(self.redis_lock, self.redis_client)
            raise

        # accumulate progress in redis for external observers
        self.redis_client.incrby(self.generator_progress_key, amount)
def validate_indexing_fence(
    tenant_id: str | None,
    key_bytes: bytes,
    reserved_tasks: set[str],
    r_celery: Redis,
    db_session: Session,
) -> None:
    """Checks for the error condition where an indexing fence is set but the
    associated celery tasks don't exist. This can happen if the indexing worker
    hard crashes or is terminated.

    Being in this bad state means the fence will never clear without help, so
    this function gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the
       following conditions
       1.1. When the task is seen in the redis queue
       1.2. When the task is seen in the reserved / prefetched list
    2. Externally, the active signal is renewed when:
       2.1. The fence is created
       2.2. The indexing watchdog checks the spawned task.
    3. The TTL allows us to get through the transitions on fence startup
       and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery
    for whether a task is in the queue or currently executing.
    1. An unknown task id is always returned as state PENDING.
    2. Redis can be inspected for the task id, but the task id is gone between
       the time a worker receives the task and the time it actually starts on
       the worker.
    """
    # if the fence doesn't exist, there's nothing to do
    fence_key = key_bytes.decode("utf-8")
    composite_id = RedisConnector.get_id_from_fence_key(fence_key)
    if composite_id is None:
        task_logger.warning(
            f"validate_indexing_fence - could not parse composite_id from {fence_key}"
        )
        return

    # parse out metadata and initialize the helper class with it
    # composite_id is expected to be "<cc_pair_id>/<search_settings_id>"
    parts = composite_id.split("/")
    if len(parts) != 2:
        return

    cc_pair_id = int(parts[0])
    search_settings_id = int(parts[1])

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings_id)

    # check to see if the fence/payload exists
    if not redis_connector_index.fenced:
        return

    payload = redis_connector_index.payload
    if not payload:
        return

    # OK, there's actually something for us to validate
    if payload.celery_task_id is None:
        # the fence is just barely set up.
        if redis_connector_index.active():
            return

        # it would be odd to get here as there isn't that much that can go wrong during
        # initial fence setup, but it's still worth making sure we can recover
        logger.info(
            f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
        )
        redis_connector_index.reset()
        return

    found = celery_find_task(
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )
    if found:
        # the celery task exists in the redis queue
        redis_connector_index.set_active()
        return

    if payload.celery_task_id in reserved_tasks:
        # the celery task was prefetched and is reserved within the indexing worker
        redis_connector_index.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    if redis_connector_index.active():
        return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    logger.warning(
        f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
        f"index_attempt={payload.index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id} "
        f"fence={fence_key}"
    )
    if payload.index_attempt_id:
        try:
            # mark the orphaned attempt failed; best-effort — a failure here
            # must not prevent the fence reset below
            mark_attempt_failed(
                payload.index_attempt_id,
                db_session,
                "validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
                f"index_attempt={payload.index_attempt_id}",
            )
        except Exception:
            logger.exception(
                "validate_indexing_fence - Exception while marking index attempt as failed: "
                f"index_attempt={payload.index_attempt_id}",
            )

    redis_connector_index.reset()
    return
def validate_indexing_fences(
    tenant_id: str | None,
    celery_app: Celery,
    r: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    """Validate every indexing fence currently present in redis.

    Scans for all indexing fence keys and hands each one to
    validate_indexing_fence, which resets fences whose backing celery tasks
    have vanished. The beat lock is reacquired each iteration so it cannot
    expire during a long scan.
    """
    reserved_indexing_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )

    fence_key_pattern = RedisConnectorIndex.FENCE_PREFIX + "*"
    for raw_fence_key in r.scan_iter(
        fence_key_pattern, count=SCAN_ITER_COUNT_DEFAULT
    ):
        # keep the beat lock alive while working through the scan
        lock_beat.reacquire()

        with get_session_with_tenant(tenant_id) as db_session:
            validate_indexing_fence(
                tenant_id,
                raw_fence_key,
                reserved_indexing_tasks,
                r_celery,
                db_session,
            )
    return
def _should_index(
    cc_pair: ConnectorCredentialPair,
    last_index: IndexAttempt | None,
    search_settings_instance: SearchSettings,
    search_settings_primary: bool,
    secondary_index_building: bool,
    db_session: Session,
) -> bool:
    """Checks various global settings and past indexing attempts to determine
    if we should try to start indexing the cc pair / search setting
    combination.

    Note that tactical checks such as preventing overlap with a currently
    running task are not handled here.

    Return True if we should try to index, False if not.
    """
    connector = cc_pair.connector

    # uncomment for debugging
    # task_logger.info(f"_should_index: "
    #                  f"cc_pair={cc_pair.id} "
    #                  f"connector={cc_pair.connector_id} "
    #                  f"refresh_freq={connector.refresh_freq}")

    # sources marked NOT_APPLICABLE are never indexed
    if connector.source == DocumentSource.NOT_APPLICABLE:
        return False

    # While a secondary (new model) index is building, optionally freeze
    # automatic indexing on the currently-live settings. Users can still
    # manually create single indexing attempts via the UI.
    if DISABLE_INDEX_UPDATE_ON_SWAP:
        if (
            search_settings_instance.status == IndexModelStatus.PRESENT
            and secondary_index_building
        ):
            return False

    # A FUTURE model must index at least once, otherwise the model can
    # never be swapped in.
    if search_settings_instance.status == IndexModelStatus.FUTURE:
        if last_index is None:
            # never auto-index connector id 0 / the ingestion API source
            if connector.id == 0 or connector.source == DocumentSource.INGESTION_API:
                return False
        elif last_index.status in (
            IndexingStatus.SUCCESS,
            IndexingStatus.NOT_STARTED,
            IndexingStatus.IN_PROGRESS,
        ):
            # one successful run is enough; queued or running attempts mean
            # there is nothing new to kick off
            return False
        return True

    # Outside the future-model path: paused pairs, connector id 0, and the
    # ingestion API never index automatically.
    # NOTE: during an embedding model switch over, this logic is bypassed by
    # the FUTURE-model check above.
    if (
        not cc_pair.status.is_active()
        or connector.id == 0
        or connector.source == DocumentSource.INGESTION_API
    ):
        return False

    if search_settings_primary:
        if cc_pair.indexing_trigger is not None:
            # a manual indexing trigger on the cc pair wins for primary settings
            return True

    # if no attempt has ever occurred, we should index regardless of refresh_freq
    if last_index is None:
        return True

    # no refresh frequency configured means no automatic re-indexing
    if connector.refresh_freq is None:
        return False

    # otherwise re-index only once the refresh interval has elapsed
    elapsed = get_db_current_time(db_session) - last_index.time_updated
    return elapsed.total_seconds() >= connector.refresh_freq
def try_creating_indexing_task(
    celery_app: Celery,
    cc_pair: ConnectorCredentialPair,
    search_settings: SearchSettings,
    reindex: bool,
    db_session: Session,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
    """Checks for any conditions that should block the indexing task from being
    created, then creates the task.

    Does not check for scheduling related conditions as this function
    is used to trigger indexing immediately.

    Returns the new index attempt id on success, or None when the task was not
    created (lock not acquired, indexing/deletion already in flight, or an
    unexpected error occurred).
    """
    LOCK_TIMEOUT = 30
    index_attempt_id: int | None = None

    # we need to serialize any attempt to trigger indexing since it can be triggered
    # either via celery beat or manually (API call)
    lock: RedisLock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
        timeout=LOCK_TIMEOUT,
    )

    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
    if not acquired:
        return None

    try:
        redis_connector = RedisConnector(tenant_id, cc_pair.id)
        redis_connector_index = redis_connector.new_index(search_settings.id)

        # skip if already indexing
        if redis_connector_index.fenced:
            return None

        # skip indexing if the cc_pair is deleting
        if redis_connector.delete.fenced:
            return None

        # re-read the pair's status in case it changed since the caller loaded it
        db_session.refresh(cc_pair)
        if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
            return None

        # add a long running generator task to the queue
        redis_connector_index.generator_clear()

        # set a basic fence to start
        payload = RedisConnectorIndexPayload(
            index_attempt_id=None,
            started=None,
            submitted=datetime.now(timezone.utc),
            celery_task_id=None,
        )

        redis_connector_index.set_active()
        redis_connector_index.set_fence(payload)

        # create the index attempt for tracking purposes
        # code elsewhere checks for index attempts without an associated redis key
        # and cleans them up
        # therefore we must create the attempt and the task after the fence goes up
        index_attempt_id = create_index_attempt(
            cc_pair.id,
            search_settings.id,
            from_beginning=reindex,
            db_session=db_session,
        )

        custom_task_id = redis_connector_index.generate_generator_task_id()

        # when the task is sent, we have yet to finish setting up the fence
        # therefore, the task must contain code that blocks until the fence is ready
        result = celery_app.send_task(
            OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
            kwargs=dict(
                index_attempt_id=index_attempt_id,
                cc_pair_id=cc_pair.id,
                search_settings_id=search_settings.id,
                tenant_id=tenant_id,
            ),
            queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
            task_id=custom_task_id,
            priority=OnyxCeleryPriority.MEDIUM,
        )
        if not result:
            raise RuntimeError("send_task for connector_indexing_proxy_task failed.")

        # now fill out the fence with the rest of the data
        redis_connector_index.set_active()

        payload.index_attempt_id = index_attempt_id
        payload.celery_task_id = result.id
        redis_connector_index.set_fence(payload)
    except Exception:
        task_logger.exception(
            f"try_creating_indexing_task - Unexpected exception: "
            f"cc_pair={cc_pair.id} "
            f"search_settings={search_settings.id}"
        )

        # roll back: drop the attempt row and take down the fence
        # NOTE(review): if the exception fires before redis_connector_index is
        # assigned (e.g. inside RedisConnector(...)), set_fence below would
        # raise NameError — confirm that path is unreachable in practice
        if index_attempt_id is not None:
            delete_index_attempt(db_session, index_attempt_id)
        redis_connector_index.set_fence(None)
        return None
    finally:
        if lock.owned():
            lock.release()

    return index_attempt_id

View File

@@ -68,20 +68,22 @@ class Metric(BaseModel):
task_logger.info(json.dumps(data))
def emit(self, tenant_id: str | None) -> None:
# Convert value to appropriate type
float_value = (
float(self.value) if isinstance(self.value, (int, float)) else None
)
int_value = int(self.value) if isinstance(self.value, int) else None
string_value = str(self.value) if isinstance(self.value, str) else None
bool_value = bool(self.value) if isinstance(self.value, bool) else None
if (
float_value is None
and int_value is None
and string_value is None
and bool_value is None
):
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None
int_value = None
string_value = None
# NOTE: have to do bool first, since `isinstance(True, int)` is true
# e.g. bool is a subclass of int
if isinstance(self.value, bool):
bool_value = self.value
elif isinstance(self.value, int):
int_value = self.value
elif isinstance(self.value, float):
float_value = self.value
elif isinstance(self.value, str):
string_value = self.value
else:
task_logger.error(
f"Invalid metric value type: {type(self.value)} "
f"({self.value}) for metric {self.name}."
@@ -183,35 +185,41 @@ def _build_connector_start_latency_metric(
)
def _build_run_success_metric(
cc_pair: ConnectorCredentialPair, recent_attempt: IndexAttempt, redis_std: Redis
) -> Metric | None:
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=recent_attempt.id,
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
f"Skipping metric for connector {cc_pair.connector.id} "
f"index attempt {recent_attempt.id} because it has already been "
"emitted"
)
return None
if recent_attempt.status in [
IndexingStatus.SUCCESS,
IndexingStatus.FAILED,
IndexingStatus.CANCELED,
]:
return Metric(
key=metric_key,
name="connector_run_succeeded",
value=recent_attempt.status == IndexingStatus.SUCCESS,
tags={"source": str(cc_pair.connector.source)},
def _build_run_success_metrics(
    cc_pair: ConnectorCredentialPair,
    recent_attempts: list[IndexAttempt],
    redis_std: Redis,
) -> list[Metric]:
    """Build one success/failure Metric per terminal index attempt.

    For each attempt in `recent_attempts`, emits a boolean
    `connector_run_succeeded` metric (True only for SUCCESS) unless a metric
    for that attempt was already emitted (deduplicated via a per-attempt
    Redis key). Attempts that are not yet in a terminal state produce no
    metric.
    """
    metrics: list[Metric] = []
    for attempt in recent_attempts:
        metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
            cc_pair_id=cc_pair.id,
            index_attempt_id=attempt.id,
        )
        # NOTE: the stray `return None` that the diff rendering interleaved
        # here belonged to the removed single-attempt variant; returning
        # inside the loop would abort after the first attempt and violate
        # the declared list[Metric] return type.
        if _has_metric_been_emitted(redis_std, metric_key):
            task_logger.info(
                f"Skipping metric for connector {cc_pair.connector.id} "
                f"index attempt {attempt.id} because it has already been "
                "emitted"
            )
            continue
        # Only terminal attempt states are reportable as success/failure.
        if attempt.status in [
            IndexingStatus.SUCCESS,
            IndexingStatus.FAILED,
            IndexingStatus.CANCELED,
        ]:
            metrics.append(
                Metric(
                    key=metric_key,
                    name="connector_run_succeeded",
                    value=attempt.status == IndexingStatus.SUCCESS,
                    tags={"source": str(cc_pair.connector.source)},
                )
            )
    return metrics
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
@@ -224,7 +232,7 @@ def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Me
metrics = []
for cc_pair in cc_pairs:
# Get most recent attempt in the last hour
# Get all attempts in the last hour
recent_attempts = (
db_session.query(IndexAttempt)
.filter(
@@ -232,31 +240,29 @@ def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Me
IndexAttempt.time_created >= one_hour_ago,
)
.order_by(IndexAttempt.time_created.desc())
.limit(2)
.all()
)
recent_attempt = recent_attempts[0] if recent_attempts else None
most_recent_attempt = recent_attempts[0] if recent_attempts else None
second_most_recent_attempt = (
recent_attempts[1] if len(recent_attempts) > 1 else None
)
# if no metric to emit, skip
if not recent_attempt:
if most_recent_attempt is None:
continue
# Connector start latency
start_latency_metric = _build_connector_start_latency_metric(
cc_pair, recent_attempt, second_most_recent_attempt, redis_std
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
)
if start_latency_metric:
metrics.append(start_latency_metric)
# Connector run success/failure
run_success_metric = _build_run_success_metric(
cc_pair, recent_attempt, redis_std
run_success_metrics = _build_run_success_metrics(
cc_pair, recent_attempts, redis_std
)
if run_success_metric:
metrics.append(run_success_metric)
metrics.extend(run_success_metrics)
return metrics

View File

@@ -13,11 +13,11 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.tasks import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -86,7 +86,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap

View File

@@ -101,10 +101,17 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
for doc in doc_batch:
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
@@ -120,6 +127,9 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
@@ -277,8 +287,6 @@ def _run_indexing(
tenant_id=tenant_id,
)
all_connector_doc_ids: set[str] = set()
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
@@ -347,16 +355,15 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
index_pipeline_result = indexing_pipeline(
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch_cleaned)
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
net_doc_change += index_pipeline_result.new_docs
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@@ -365,9 +372,6 @@ def _run_indexing(
# be inaccurate
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_tenant(tenant_id) as db_session_temp:
update_docs_indexed(
@@ -378,6 +382,9 @@ def _run_indexing(
docs_removed_from_index=0,
)
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0

View File

@@ -25,7 +25,7 @@ from onyx.db.models import Persona
from onyx.db.models import Prompt
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.persona import get_prompts_by_ids
from onyx.db.prompts import get_prompts_by_ids
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest

View File

@@ -1,12 +1,13 @@
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.persona import get_default_prompt__read_only
from onyx.db.prompts import get_default_prompt
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -97,11 +98,12 @@ def compute_max_document_tokens(
def compute_max_document_tokens_for_persona(
db_session: Session,
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,

View File

@@ -7,26 +7,6 @@ from onyx.db.models import ChatMessage
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    """Render a preview of the final prompt using placeholder values.

    The user query (and, when retrieval is enabled, the context docs) are
    substituted with literal `<USER_QUERY>` / `<CONTEXT_DOCS>` markers so the
    prompt structure can be shown without real data.
    """
    # Keyword arguments common to both templates.
    fill_values = {
        "user_query": "<USER_QUERY>",
        "system_prompt": system_prompt,
        "task_prompt": task_prompt,
    }
    if retrieval_disabled:
        template = PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
    else:
        template = PARAMATERIZED_PROMPT
        fill_values["context_docs_str"] = "<CONTEXT_DOCS>"
    return template.format(**fill_values).strip()
def translate_onyx_msg_to_langchain(

View File

@@ -79,6 +79,8 @@ KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
# NOTE: we use this timeout / 4 in various places to refresh a lock
# might be worth separating this timeout into separate timeouts for each situation
CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
@@ -198,6 +200,7 @@ class SessionType(str, Enum):
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
MIXED = "mixed" # User likes some answers and dislikes other, used for chat session metrics
class SearchFeedbackType(str, Enum):
@@ -291,6 +294,8 @@ class OnyxRedisLocks:
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat"
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
@@ -304,6 +309,13 @@ class OnyxCeleryPriority(int, Enum):
LOWEST = auto()
# a prefix used to distinguish system wide tasks in the cloud
ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
# the tenant id we use for system level redis operations
ONYX_CLOUD_TENANT_ID = "cloud"
class OnyxCeleryTask:
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -331,6 +343,8 @@ class OnyxCeleryTask:
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15

View File

@@ -258,7 +258,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
user_emails.append(email)
return user_emails
def _get_all_drive_ids(self) -> set[str]:
def get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
@@ -353,7 +353,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
) -> Iterator[GoogleDriveFileType]:
all_org_emails: list[str] = self._get_all_user_emails()
all_drive_ids: set[str] = self._get_all_drive_ids()
all_drive_ids: set[str] = self.get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
@@ -437,7 +437,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
# If all 3 are true, we already yielded from get_all_files_for_oauth
return
all_drive_ids = self._get_all_drive_ids()
all_drive_ids = self.get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:

View File

@@ -252,6 +252,7 @@ def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
id=file["webViewLink"],
perm_sync_data={
"doc_id": file.get("id"),
"drive_id": file.get("driveId"),
"permissions": file.get("permissions", []),
"permission_ids": file.get("permissionIds", []),
"name": file.get("name"),

View File

@@ -19,7 +19,7 @@ FILE_FIELDS = (
"shortcutDetails, owners(emailAddress), size)"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
"nextPageToken, files(mimeType, driveId, id, name, permissions(emailAddress, type), "
"permissionIds, webViewLink, owners(emailAddress))"
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"

View File

@@ -12,8 +12,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.enums import AccessType
@@ -35,13 +33,8 @@ logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -101,7 +94,7 @@ def _add_user_filters(
def get_connector_credential_pairs_for_user(
db_session: Session,
user: User | None | SystemUser,
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
@@ -112,7 +105,6 @@ def get_connector_credential_pairs_for_user(
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
@@ -123,11 +115,12 @@ def get_connector_credential_pairs(
db_session: Session,
ids: list[int] | None = None,
) -> list[ConnectorCredentialPair]:
return get_connector_credential_pairs_for_user(
db_session=db_session,
user=SYSTEM_USER,
ids=ids,
)
stmt = select(ConnectorCredentialPair).distinct()
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
def add_deletion_failure_message(
@@ -162,7 +155,7 @@ def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
credential_id: int,
user: User | None | SystemUser,
user: User | None,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
@@ -178,18 +171,17 @@ def get_connector_credential_pair(
connector_id: int,
credential_id: int,
) -> ConnectorCredentialPair | None:
return get_connector_credential_pair_for_user(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
user=SYSTEM_USER,
)
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
def get_connector_credential_pair_from_id_for_user(
cc_pair_id: int,
db_session: Session,
user: User | None | SystemUser,
user: User | None,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
@@ -203,11 +195,10 @@ def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
) -> ConnectorCredentialPair | None:
return get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=SYSTEM_USER,
)
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
def get_last_successful_attempt_time(

View File

@@ -1,11 +1 @@
from typing import Final
SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"
class SystemUser:
"""Represents the system user for internal operations.

Used as a sentinel user type so internal code paths can bypass the
per-user permission filtering applied in `_add_user_filters` (which
returns the statement unfiltered for this type).
"""
# Module-level singleton instance of the sentinel; internal callers pass
# this exact object where a `User | None | SystemUser` is accepted.
SYSTEM_USER: Final = SystemUser()

View File

@@ -14,8 +14,6 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -44,17 +42,11 @@ PUBLIC_CREDENTIAL_ID = 0
def _add_user_filters(
stmt: Select,
user: User | None | SystemUser,
user: User | None,
get_editable: bool = True,
) -> Select:
"""Attaches filters to the statement to ensure that the user can only
access the appropriate credentials"""
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")
if user is None:
if not DISABLE_AUTH:
raise ValueError("Anonymous users are not allowed to access credentials")
@@ -159,7 +151,7 @@ def fetch_credentials_for_user(
def fetch_credential_by_id_for_user(
credential_id: int,
user: User | None | SystemUser,
user: User | None,
db_session: Session,
get_editable: bool = True,
) -> Credential | None:
@@ -179,16 +171,16 @@ def fetch_credential_by_id(
db_session: Session,
credential_id: int,
) -> Credential | None:
return fetch_credential_by_id_for_user(
credential_id=credential_id,
user=SYSTEM_USER,
db_session=db_session,
)
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential
def fetch_credentials_by_source_for_user(
db_session: Session,
user: User | None | SystemUser,
user: User | None,
document_source: DocumentSource | None = None,
get_editable: bool = True,
) -> list[Credential]:
@@ -202,11 +194,9 @@ def fetch_credentials_by_source(
db_session: Session,
document_source: DocumentSource | None = None,
) -> list[Credential]:
return fetch_credentials_by_source_for_user(
db_session=db_session,
user=SYSTEM_USER,
document_source=document_source,
)
base_query = select(Credential).where(Credential.source == document_source)
credentials = db_session.execute(base_query).scalars().all()
return list(credentials)
def swap_credentials_connector(

View File

@@ -1,6 +1,7 @@
import contextlib
import time
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
@@ -13,6 +14,7 @@ from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import tuple_
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.util import TransactionalContext
from sqlalchemy.exc import OperationalError
@@ -226,10 +228,13 @@ def get_document_counts_for_cc_pairs(
func.count(),
)
.where(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids)
and_(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids),
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
)
)
.group_by(
DocumentByConnectorCredentialPair.connector_id,
@@ -382,18 +387,40 @@ def upsert_document_by_connector_credential_pair(
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
has_been_indexed=False,
)
)
for doc_id in document_ids
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
# this must be `on_conflict_do_nothing` rather than `on_conflict_do_update`
# since we don't want to update the `has_been_indexed` field for documents
# that already exist
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
db_session.execute(on_conflict_stmt)
db_session.commit()
def mark_document_as_indexed_for_cc_pair__no_commit(
    db_session: Session,
    connector_id: int,
    credential_id: int,
    document_ids: Iterable[str],
) -> None:
    """Should be called only after a successful index operation for a batch."""
    # Restrict the update to rows of this exact connector/credential pair
    # that belong to the just-indexed batch of document ids.
    row_filter = and_(
        DocumentByConnectorCredentialPair.connector_id == connector_id,
        DocumentByConnectorCredentialPair.credential_id == credential_id,
        DocumentByConnectorCredentialPair.id.in_(document_ids),
    )
    mark_stmt = (
        update(DocumentByConnectorCredentialPair)
        .where(row_filter)
        .values(has_been_indexed=True)
    )
    # The caller owns the transaction boundary (hence `__no_commit`).
    db_session.execute(mark_stmt)
def update_docs_updated_at__no_commit(
ids_to_new_updated_at: dict[str, datetime],
db_session: Session,

View File

@@ -15,8 +15,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
@@ -37,13 +35,8 @@ logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -494,7 +487,7 @@ def fetch_document_sets(
def fetch_all_document_sets_for_user(
db_session: Session,
user: User | None | SystemUser,
user: User | None,
get_editable: bool = True,
) -> Sequence[DocumentSetDBModel]:
stmt = select(DocumentSetDBModel).distinct()
@@ -502,15 +495,6 @@ def fetch_all_document_sets_for_user(
return db_session.scalars(stmt).all()
def fetch_all_document_sets(
db_session: Session,
) -> Sequence[DocumentSetDBModel]:
"""Fetch every document set, bypassing per-user access filtering.

Delegates to `fetch_all_document_sets_for_user` with the SYSTEM_USER
sentinel, which causes the user-based filters to be skipped.
"""
return fetch_all_document_sets_for_user(
db_session=db_session,
user=SYSTEM_USER,
)
def fetch_documents_for_document_set_paginated(
document_set_id: int,
db_session: Session,

View File

@@ -240,8 +240,11 @@ class SqlEngine:
def get_all_tenant_ids() -> list[str] | list[None]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
result = session.execute(
text(

View File

@@ -941,6 +941,12 @@ class DocumentByConnectorCredentialPair(Base):
ForeignKey("credential.id"), primary_key=True
)
# used to better keep track of document counts at a connector level
# e.g. if a document is added as part of permission syncing, it should
# not be counted as part of the connector's document count until
# the actual indexing is complete
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="documents_by_connector"
)
@@ -955,6 +961,14 @@ class DocumentByConnectorCredentialPair(Base):
"credential_id",
unique=False,
),
# Index to optimize get_document_counts_for_cc_pairs query pattern
Index(
"idx_document_cc_pair_counts",
"connector_id",
"credential_id",
"has_been_indexed",
unique=False,
),
)

View File

@@ -1,6 +1,5 @@
from collections.abc import Sequence
from datetime import datetime
from functools import lru_cache
from uuid import UUID
from fastapi import HTTPException
@@ -8,7 +7,6 @@ from sqlalchemy import delete
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
@@ -23,9 +21,6 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import DocumentSet
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -37,8 +32,8 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -46,13 +41,8 @@ logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -114,14 +104,8 @@ def _add_user_filters(
return stmt.where(where_clause)
# fetch_persona_by_id is used to fetch a persona by its ID. It is used to fetch a persona by its ID.
def fetch_persona_by_id_for_user(
db_session: Session,
persona_id: int,
user: User | None | SystemUser,
get_editable: bool = True,
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
@@ -134,17 +118,6 @@ def fetch_persona_by_id_for_user(
return persona
def fetch_persona_by_id(
db_session: Session,
persona_id: int,
) -> Persona:
"""Fetch a persona by ID without user-level access checks.

Delegates to `fetch_persona_by_id_for_user` with the SYSTEM_USER
sentinel. The non-optional `Persona` return type implies a missing ID is
expected to raise inside the delegate — TODO confirm against its body.
"""
return fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=persona_id,
user=SYSTEM_USER,
)
def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
@@ -205,7 +178,7 @@ def make_persona_private(
def create_update_persona(
persona_id: int | None,
create_persona_request: CreatePersonaRequest,
create_persona_request: PersonaUpsertRequest,
user: User | None,
db_session: Session,
) -> PersonaSnapshot:
@@ -213,14 +186,36 @@ def create_update_persona(
# Permission to actually use these is checked later
try:
persona_data = {
"persona_id": persona_id,
"user": user,
"db_session": db_session,
**create_persona_request.model_dump(exclude={"users", "groups"}),
}
all_prompt_ids = create_persona_request.prompt_ids
persona = upsert_persona(**persona_data)
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
persona = upsert_persona(
persona_id=persona_id,
user=user,
db_session=db_session,
description=create_persona_request.description,
name=create_persona_request.name,
prompt_ids=all_prompt_ids,
document_set_ids=create_persona_request.document_set_ids,
tool_ids=create_persona_request.tool_ids,
is_public=create_persona_request.is_public,
recency_bias=create_persona_request.recency_bias,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
icon_color=create_persona_request.icon_color,
icon_shape=create_persona_request.icon_shape,
uploaded_image_id=create_persona_request.uploaded_image_id,
display_priority=create_persona_request.display_priority,
remove_image=create_persona_request.remove_image,
search_start_date=create_persona_request.search_start_date,
label_ids=create_persona_request.label_ids,
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
)
versioned_make_persona_private = fetch_versioned_implementation(
"onyx.db.persona", "make_persona_private"
@@ -286,27 +281,9 @@ def update_persona_public_status(
db_session.commit()
def get_prompts(
    user_id: UUID | None,
    db_session: Session,
    include_default: bool = True,
    include_deleted: bool = False,
) -> Sequence[Prompt]:
    """Return prompts owned by `user_id` plus ownerless (global) prompts.

    Default and soft-deleted prompts can be excluded via the flags; both are
    included by default.
    """
    # Global prompts have no owner (user_id IS NULL), so they always match.
    conditions = [or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))]
    if not include_default:
        conditions.append(Prompt.default_prompt.is_(False))
    if not include_deleted:
        conditions.append(Prompt.deleted.is_(False))
    # Multiple criteria passed to .where() are combined with AND.
    query = select(Prompt).where(*conditions)
    return db_session.scalars(query).all()
def get_personas_for_user(
# if user is `None` assume the user is an admin or auth is disabled
user: User | None | SystemUser,
user: User | None,
db_session: Session,
get_editable: bool = True,
include_default: bool = True,
@@ -336,10 +313,10 @@ def get_personas_for_user(
def get_personas(db_session: Session) -> Sequence[Persona]:
return get_personas_for_user(
user=SYSTEM_USER,
db_session=db_session,
)
stmt = select(Persona).distinct()
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
stmt = stmt.where(Persona.deleted.is_(False))
return db_session.execute(stmt).unique().scalars().all()
def mark_persona_as_deleted(
@@ -395,65 +372,6 @@ def update_all_personas_display_priority(
db_session.commit()
def upsert_prompt(
user: User | None,
name: str,
description: str,
system_prompt: str,
task_prompt: str,
include_citations: bool,
datetime_aware: bool,
personas: list[Persona] | None,
db_session: Session,
prompt_id: int | None = None,
default_prompt: bool = True,
commit: bool = True,
) -> Prompt:
"""Create or update a Prompt row.

Lookup is by `prompt_id` when given, otherwise by (name, user). If a
matching prompt exists, its fields are overwritten in place; otherwise a
new row is inserted.

Raises:
ValueError: when trying to overwrite an existing default prompt with
`default_prompt=False`.
"""
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
if prompt:
# Existing prompt: overwrite all mutable fields in place.
if not default_prompt and prompt.default_prompt:
raise ValueError("Cannot update default prompt with non-default.")
prompt.name = name
prompt.description = description
prompt.system_prompt = system_prompt
prompt.task_prompt = task_prompt
prompt.include_citations = include_citations
prompt.datetime_aware = datetime_aware
prompt.default_prompt = default_prompt
# Replace (not merge) the persona associations when a list is provided;
# `personas=None` leaves the existing associations untouched.
if personas is not None:
prompt.personas.clear()
prompt.personas = personas
else:
# No match: insert a fresh row.
prompt = Prompt(
id=prompt_id,
user_id=user.id if user else None,
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
datetime_aware=datetime_aware,
default_prompt=default_prompt,
personas=personas or [],
)
db_session.add(prompt)
if commit:
db_session.commit()
else:
# Flush the session so that the Prompt has an ID
db_session.flush()
return prompt
def upsert_persona(
user: User | None,
name: str,
@@ -498,6 +416,15 @@ def upsert_persona(
persona_name=name, user=user, db_session=db_session
)
if existing_persona:
# this checks if the user has permission to edit the persona
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)
# Fetch and attach tools by IDs
tools = None
if tool_ids is not None:
@@ -543,15 +470,6 @@ def upsert_persona(
if existing_persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)
# The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona.
@@ -640,16 +558,6 @@ def upsert_persona(
return persona
def mark_prompt_as_deleted(
prompt_id: int,
user: User | None,
db_session: Session,
) -> None:
prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session)
prompt.deleted = True
db_session.commit()
def delete_old_default_personas(
db_session: Session,
) -> None:
@@ -687,69 +595,6 @@ def validate_persona_tools(tools: list[Tool]) -> None:
)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return list(prompts)
def get_prompt_by_id(
prompt_id: int,
user: User | None,
db_session: Session,
include_deleted: bool = False,
) -> Prompt:
stmt = select(Prompt).where(Prompt.id == prompt_id)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None)))
if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise ValueError(
f"Prompt with ID {prompt_id} does not exist or does not belong to user"
)
return prompt
def _get_default_prompt(db_session: Session) -> Prompt:
stmt = select(Prompt).where(Prompt.id == 0)
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise RuntimeError("Default Prompt not found")
return prompt
def get_default_prompt(db_session: Session) -> Prompt:
return _get_default_prompt(db_session)
@lru_cache()
def get_default_prompt__read_only() -> Prompt:
"""Due to the way lru_cache / SQLAlchemy works, this can cause issues
when trying to attach the returned `Prompt` object to a `Persona`. If you are
doing anything other than reading, you should use the `get_default_prompt`
method instead."""
with Session(get_sqlalchemy_engine()) as db_session:
return _get_default_prompt(db_session)
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
# a direct mapping indicating whether a user has access to a specific persona?
def get_persona_by_id(
@@ -821,22 +666,6 @@ def get_personas_by_ids(
return personas
def get_prompt_by_name(
prompt_name: str, user: User | None, db_session: Session
) -> Prompt | None:
stmt = select(Prompt).where(Prompt.name == prompt_name)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def delete_persona_by_name(
persona_name: str, db_session: Session, is_default: bool = True
) -> None:

119
backend/onyx/db/prompts.py Normal file
View File

@@ -0,0 +1,119 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.models import Persona
from onyx.db.models import Prompt
from onyx.db.models import User
from onyx.utils.logger import setup_logger
# Note: As prompts are fairly innocuous/harmless, there are no protections
# to prevent users from messing with prompts of other users.
logger = setup_logger()
def _get_default_prompt(db_session: Session) -> Prompt:
stmt = select(Prompt).where(Prompt.id == 0)
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise RuntimeError("Default Prompt not found")
return prompt
def get_default_prompt(db_session: Session) -> Prompt:
return _get_default_prompt(db_session)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return list(prompts)
def get_prompt_by_name(
prompt_name: str, user: User | None, db_session: Session
) -> Prompt | None:
stmt = select(Prompt).where(Prompt.name == prompt_name)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def build_prompt_name_from_persona_name(persona_name: str) -> str:
return f"default-prompt__{persona_name}"
def upsert_prompt(
db_session: Session,
user: User | None,
name: str,
system_prompt: str,
task_prompt: str,
datetime_aware: bool,
prompt_id: int | None = None,
personas: list[Persona] | None = None,
include_citations: bool = False,
default_prompt: bool = True,
# Support backwards compatibility
description: str | None = None,
) -> Prompt:
if description is None:
description = f"Default prompt for {name}"
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
if prompt:
if not default_prompt and prompt.default_prompt:
raise ValueError("Cannot update default prompt with non-default.")
prompt.name = name
prompt.description = description
prompt.system_prompt = system_prompt
prompt.task_prompt = task_prompt
prompt.include_citations = include_citations
prompt.datetime_aware = datetime_aware
prompt.default_prompt = default_prompt
if personas is not None:
prompt.personas.clear()
prompt.personas = personas
else:
prompt = Prompt(
id=prompt_id,
user_id=user.id if user else None,
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
datetime_aware=datetime_aware,
default_prompt=default_prompt,
personas=personas or [],
)
db_session.add(prompt)
# Flush the session so that the Prompt has an ID
db_session.flush()
return prompt

View File

@@ -12,9 +12,9 @@ from onyx.db.models import Persona
from onyx.db.models import Persona__DocumentSet
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
from onyx.db.persona import get_default_prompt
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.prompts import get_default_prompt
from onyx.utils.errors import EERequiredError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,

View File

@@ -21,6 +21,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
from onyx.db.document import prepare_to_modify_documents
from onyx.db.document import update_docs_chunk_count__no_commit
from onyx.db.document import update_docs_last_modified__no_commit
@@ -55,12 +56,23 @@ class DocumentBatchPrepareContext(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class IndexingPipelineResult(BaseModel):
# number of documents that are completely new (e.g. did
# not exist as a part of this OR any other connector)
new_docs: int
# NOTE: need total_docs, since the pipeline can skip some docs
# (e.g. not even insert them into Postgres)
total_docs: int
# number of chunks that were inserted into Vespa
total_chunks: int
class IndexingPipelineProtocol(Protocol):
def __call__(
self,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
) -> tuple[int, int]:
) -> IndexingPipelineResult:
...
@@ -147,10 +159,12 @@ def index_doc_batch_with_handler(
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
r = (0, 0)
) -> IndexingPipelineResult:
index_pipeline_result = IndexingPipelineResult(
new_docs=0, total_docs=len(document_batch), total_chunks=0
)
try:
r = index_doc_batch(
index_pipeline_result = index_doc_batch(
chunker=chunker,
embedder=embedder,
document_index=document_index,
@@ -203,7 +217,7 @@ def index_doc_batch_with_handler(
else:
pass
return r
return index_pipeline_result
def index_doc_batch_prepare(
@@ -227,6 +241,15 @@ def index_doc_batch_prepare(
if not ignore_time_skip
else documents
)
if len(updatable_docs) != len(documents):
updatable_doc_ids = [doc.id for doc in updatable_docs]
skipped_doc_ids = [
doc.id for doc in documents if doc.id not in updatable_doc_ids
]
logger.info(
f"Skipping {len(skipped_doc_ids)} documents "
f"because they are up to date. Skipped doc IDs: {skipped_doc_ids}"
)
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
@@ -263,21 +286,6 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
# Remove any NUL characters from title/semantic_id
# This is a known issue with the Zendesk connector
# Postgres cannot handle NUL characters in text fields
if document.title:
document.title = document.title.replace("\x00", "")
if document.semantic_identifier:
document.semantic_identifier = document.semantic_identifier.replace(
"\x00", ""
)
# Remove NUL characters from all sections
for section in document.sections:
if section.text is not None:
section.text = section.text.replace("\x00", "")
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
@@ -333,7 +341,7 @@ def index_doc_batch(
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> tuple[int, int]:
) -> IndexingPipelineResult:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements
@@ -359,7 +367,18 @@ def index_doc_batch(
db_session=db_session,
)
if not ctx:
return 0, 0
# even though we didn't actually index anything, we should still
# mark them as "completed" for the CC Pair in order to make the
# counts match
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
return IndexingPipelineResult(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
)
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
@@ -425,7 +444,8 @@ def index_doc_batch(
]
logger.debug(
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
"Indexing the following chunks: "
f"{[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
)
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
@@ -440,14 +460,17 @@ def index_doc_batch(
),
)
successful_doc_ids = [record.document_id for record in insertion_records]
successful_docs = [
doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids
]
successful_doc_ids = {record.document_id for record in insertion_records}
if successful_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Successful IDs: {successful_doc_ids}"
)
last_modified_ids = []
ids_to_new_updated_at = {}
for doc in successful_docs:
for doc in ctx.updatable_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
@@ -469,11 +492,24 @@ def index_doc_batch(
db_session=db_session,
)
# these documents can now be counted as part of the CC Pairs
# document count, so we need to mark them as indexed
# NOTE: even documents we skipped since they were already up
# to date should be counted here in order to maintain parity
# between CC Pair and index attempt counts
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
db_session.commit()
result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
result = IndexingPipelineResult(
new_docs=len([r for r in insertion_records if r.already_existed is False]),
total_docs=len(filtered_documents),
total_chunks=len(access_aware_chunks),
)
return result

View File

@@ -64,7 +64,6 @@ from onyx.server.features.input_prompt.api import (
from onyx.server.features.notifications.api import router as notification_router
from onyx.server.features.persona.api import admin_router as admin_persona_router
from onyx.server.features.persona.api import basic_router as persona_router
from onyx.server.features.prompt.api import basic_router as prompt_router
from onyx.server.features.tool.api import admin_router as admin_tool_router
from onyx.server.features.tool.api import router as tool_router
from onyx.server.gpts.api import router as gpts_router
@@ -296,7 +295,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)
include_router_with_global_prefix_prepended(application, state_router)

View File

@@ -118,32 +118,6 @@ You should always get right to the point, and never use extraneous language.
"""
# This is only for visualization for the users to specify their own prompts
# The actual flow does not work like this
PARAMATERIZED_PROMPT = f"""
{{system_prompt}}
CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
{{task_prompt}}
{QUESTION_PAT.upper()} {{user_query}}
RESPONSE:
""".strip()
PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f"""
{{system_prompt}}
{{task_prompt}}
{QUESTION_PAT.upper()} {{user_query}}
RESPONSE:
""".strip()
# CURRENTLY DISABLED, CANNOT USE THIS ONE
# Default chain-of-thought style json prompt which uses multiple docs
# This one has a section for the LLM to output some non-answer "thoughts"

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -147,7 +147,7 @@ class RedisConnectorPermissionSync:
for doc_perm in new_permissions:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time

View File

@@ -7,7 +7,7 @@ from celery import Celery
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -125,7 +125,7 @@ class RedisConnectorPrune:
for doc_id in documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time

View File

@@ -12,9 +12,9 @@ from onyx.db.models import DocumentSet as DocumentSetDBModel
from onyx.db.models import Persona
from onyx.db.models import Prompt as PromptDBModel
from onyx.db.models import Tool as ToolDBModel
from onyx.db.persona import get_prompt_by_name
from onyx.db.persona import upsert_persona
from onyx.db.persona import upsert_prompt
from onyx.db.prompts import get_prompt_by_name
from onyx.db.prompts import upsert_prompt
def load_prompts_from_yaml(
@@ -26,6 +26,7 @@ def load_prompts_from_yaml(
all_prompts = data.get("prompts", [])
for prompt in all_prompts:
upsert_prompt(
db_session=db_session,
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
@@ -36,9 +37,8 @@ def load_prompts_from_yaml(
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
db_session.commit()
def load_input_prompts_from_yaml(

View File

@@ -7,6 +7,7 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from ee.onyx.server.query_history.models import ChatSessionMinimal
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import DocumentErrorSummary
@@ -212,6 +213,7 @@ PaginatedType = TypeVar(
IndexAttemptSnapshot,
FullUserSnapshot,
InvitedUserSnapshot,
ChatSessionMinimal,
)

View File

@@ -14,7 +14,6 @@ from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.chat.prompt_builder.utils import build_dummy_prompt
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
@@ -36,19 +35,21 @@ from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared_users
from onyx.db.persona import update_persona_visibility
from onyx.db.prompts import build_prompt_name_from_persona_name
from onyx.db.prompts import upsert_prompt
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaLabelCreate
from onyx.server.features.persona.models import PersonaLabelResponse
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PromptTemplateResponse
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.features.persona.models import PromptSnapshot
from onyx.server.models import DisplayPriorityRequest
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
@@ -173,18 +174,37 @@ def upload_file(
@basic_router.post("")
def create_persona(
create_persona_request: CreatePersonaRequest,
persona_upsert_request: PersonaUpsertRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> PersonaSnapshot:
prompt_id = (
persona_upsert_request.prompt_ids[0]
if persona_upsert_request.prompt_ids
and len(persona_upsert_request.prompt_ids) > 0
else None
)
prompt = upsert_prompt(
db_session=db_session,
user=user,
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
datetime_aware=False,
include_citations=persona_upsert_request.include_citations,
prompt_id=prompt_id,
)
prompt_snapshot = PromptSnapshot.from_model(prompt)
persona_upsert_request.prompt_ids = [prompt.id]
persona_snapshot = create_update_persona(
persona_id=None,
create_persona_request=create_persona_request,
create_persona_request=persona_upsert_request,
user=user,
db_session=db_session,
)
persona_snapshot.prompts = [prompt_snapshot]
create_milestone_and_report(
user=user,
distinct_id=tenant_id or "N/A",
@@ -202,16 +222,37 @@ def create_persona(
@basic_router.patch("/{persona_id}")
def update_persona(
persona_id: int,
update_persona_request: CreatePersonaRequest,
persona_upsert_request: PersonaUpsertRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
return create_update_persona(
prompt_id = (
persona_upsert_request.prompt_ids[0]
if persona_upsert_request.prompt_ids
and len(persona_upsert_request.prompt_ids) > 0
else None
)
prompt = upsert_prompt(
db_session=db_session,
user=user,
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
datetime_aware=False,
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
include_citations=persona_upsert_request.include_citations,
prompt_id=prompt_id,
)
prompt_snapshot = PromptSnapshot.from_model(prompt)
persona_upsert_request.prompt_ids = [prompt.id]
persona_snapshot = create_update_persona(
persona_id=persona_id,
create_persona_request=update_persona_request,
create_persona_request=persona_upsert_request,
user=user,
db_session=db_session,
)
persona_snapshot.prompts = [prompt_snapshot]
return persona_snapshot
class PersonaLabelPatchRequest(BaseModel):
@@ -365,22 +406,6 @@ def get_persona(
)
@basic_router.get("/utils/prompt-explorer")
def build_final_template_prompt(
system_prompt: str,
task_prompt: str,
retrieval_disabled: bool = False,
_: User | None = Depends(current_user),
) -> PromptTemplateResponse:
return PromptTemplateResponse(
final_prompt_template=build_dummy_prompt(
system_prompt=system_prompt,
task_prompt=task_prompt,
retrieval_disabled=retrieval_disabled,
)
)
@basic_router.post("/assistant-prompt-refresh")
def build_assistant_prompts(
generate_persona_prompt_request: GenerateStarterMessageRequest,

View File

@@ -7,9 +7,9 @@ from pydantic import Field
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.models import Persona
from onyx.db.models import PersonaLabel
from onyx.db.models import Prompt
from onyx.db.models import StarterMessage
from onyx.server.features.document_set.models import DocumentSet
from onyx.server.features.prompt.models import PromptSnapshot
from onyx.server.features.tool.models import ToolSnapshot
from onyx.server.models import MinimalUserSnapshot
from onyx.utils.logger import setup_logger
@@ -18,6 +18,34 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class PromptSnapshot(BaseModel):
id: int
name: str
description: str
system_prompt: str
task_prompt: str
include_citations: bool
datetime_aware: bool
default_prompt: bool
# Not including persona info, not needed
@classmethod
def from_model(cls, prompt: Prompt) -> "PromptSnapshot":
if prompt.deleted:
raise ValueError("Prompt has been deleted")
return PromptSnapshot(
id=prompt.id,
name=prompt.name,
description=prompt.description,
system_prompt=prompt.system_prompt,
task_prompt=prompt.task_prompt,
include_citations=prompt.include_citations,
datetime_aware=prompt.datetime_aware,
default_prompt=prompt.default_prompt,
)
# More minimal request for generating a persona prompt
class GenerateStarterMessageRequest(BaseModel):
name: str
@@ -27,32 +55,35 @@ class GenerateStarterMessageRequest(BaseModel):
generation_count: int
class CreatePersonaRequest(BaseModel):
class PersonaUpsertRequest(BaseModel):
name: str
description: str
system_prompt: str
task_prompt: str
document_set_ids: list[int]
num_chunks: float
llm_relevance_filter: bool
include_citations: bool
is_public: bool
llm_filter_extraction: bool
recency_bias: RecencyBiasSetting
prompt_ids: list[int]
document_set_ids: list[int]
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
tool_ids: list[int]
llm_filter_extraction: bool
llm_relevance_filter: bool
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None
# For Private Personas, who should be able to access these
users: list[UUID] = Field(default_factory=list)
groups: list[int] = Field(default_factory=list)
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
tool_ids: list[int]
icon_color: str | None = None
icon_shape: int | None = None
uploaded_image_id: str | None = None # New field for uploaded image
remove_image: bool | None = None
is_default_persona: bool = False
display_priority: int | None = None
uploaded_image_id: str | None = None # New field for uploaded image
search_start_date: datetime | None = None
label_ids: list[int] | None = None
is_default_persona: bool = False
display_priority: int | None = None
class PersonaSnapshot(BaseModel):

View File

@@ -1,152 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from starlette import status
from onyx.auth.users import current_user
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.persona import get_personas_by_ids
from onyx.db.persona import get_prompt_by_id
from onyx.db.persona import get_prompts
from onyx.db.persona import mark_prompt_as_deleted
from onyx.db.persona import upsert_prompt
from onyx.server.features.prompt.models import CreatePromptRequest
from onyx.server.features.prompt.models import PromptSnapshot
from onyx.utils.logger import setup_logger
# Note: As prompts are fairly innocuous/harmless, there are no protections
# to prevent users from messing with prompts of other users.
logger = setup_logger()
basic_router = APIRouter(prefix="/prompt")
def create_update_prompt(
prompt_id: int | None,
create_prompt_request: CreatePromptRequest,
user: User | None,
db_session: Session,
) -> PromptSnapshot:
personas = (
list(
get_personas_by_ids(
persona_ids=create_prompt_request.persona_ids,
db_session=db_session,
)
)
if create_prompt_request.persona_ids
else []
)
prompt = upsert_prompt(
prompt_id=prompt_id,
user=user,
name=create_prompt_request.name,
description=create_prompt_request.description,
system_prompt=create_prompt_request.system_prompt,
task_prompt=create_prompt_request.task_prompt,
include_citations=create_prompt_request.include_citations,
datetime_aware=create_prompt_request.datetime_aware,
personas=personas,
db_session=db_session,
)
return PromptSnapshot.from_model(prompt)
@basic_router.post("")
def create_prompt(
create_prompt_request: CreatePromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PromptSnapshot:
try:
return create_update_prompt(
prompt_id=None,
create_prompt_request=create_prompt_request,
user=user,
db_session=db_session,
)
except ValueError as ve:
logger.exception(ve)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to create Persona, invalid info.",
)
except Exception as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later.",
)
@basic_router.patch("/{prompt_id}")
def update_prompt(
prompt_id: int,
update_prompt_request: CreatePromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PromptSnapshot:
try:
return create_update_prompt(
prompt_id=prompt_id,
create_prompt_request=update_prompt_request,
user=user,
db_session=db_session,
)
except ValueError as ve:
logger.exception(ve)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to create Persona, invalid info.",
)
except Exception as e:
logger.exception(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred. Please try again later.",
)
@basic_router.delete("/{prompt_id}")
def delete_prompt(
prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
mark_prompt_as_deleted(
prompt_id=prompt_id,
user=user,
db_session=db_session,
)
@basic_router.get("")
def list_prompts(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[PromptSnapshot]:
user_id = user.id if user is not None else None
return [
PromptSnapshot.from_model(prompt)
for prompt in get_prompts(user_id=user_id, db_session=db_session)
]
@basic_router.get("/{prompt_id}")
def get_prompt(
prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PromptSnapshot:
return PromptSnapshot.from_model(
get_prompt_by_id(
prompt_id=prompt_id,
user=user,
db_session=db_session,
)
)

View File

@@ -1,41 +0,0 @@
from pydantic import BaseModel
from onyx.db.models import Prompt
class CreatePromptRequest(BaseModel):
name: str
description: str
system_prompt: str
task_prompt: str
include_citations: bool = False
datetime_aware: bool = False
persona_ids: list[int] | None = None
class PromptSnapshot(BaseModel):
id: int
name: str
description: str
system_prompt: str
task_prompt: str
include_citations: bool
datetime_aware: bool
default_prompt: bool
# Not including persona info, not needed
@classmethod
def from_model(cls, prompt: Prompt) -> "PromptSnapshot":
if prompt.deleted:
raise ValueError("Prompt has been deleted")
return PromptSnapshot(
id=prompt.id,
name=prompt.name,
description=prompt.description,
system_prompt=prompt.system_prompt,
task_prompt=prompt.task_prompt,
include_citations=prompt.include_citations,
datetime_aware=prompt.datetime_aware,
default_prompt=prompt.default_prompt,
)

View File

@@ -27,6 +27,8 @@ from onyx.server.manage.models import SlackBot
from onyx.server.manage.models import SlackBotCreationRequest
from onyx.server.manage.models import SlackChannelConfig
from onyx.server.manage.models import SlackChannelConfigCreationRequest
from onyx.server.manage.validate_tokens import validate_app_token
from onyx.server.manage.validate_tokens import validate_bot_token
from onyx.utils.telemetry import create_milestone_and_report
@@ -222,6 +224,9 @@ def create_bot(
_: User | None = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> SlackBot:
validate_app_token(slack_bot_creation_request.app_token)
validate_bot_token(slack_bot_creation_request.bot_token)
slack_bot_model = insert_slack_bot(
db_session=db_session,
name=slack_bot_creation_request.name,
@@ -248,6 +253,8 @@ def patch_bot(
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> SlackBot:
validate_bot_token(slack_bot_creation_request.bot_token)
validate_app_token(slack_bot_creation_request.app_token)
slack_bot_model = update_slack_bot(
db_session=db_session,
slack_bot_id=slack_bot_id,

View File

@@ -0,0 +1,43 @@
import requests
from fastapi import HTTPException
SLACK_API_URL = "https://slack.com/api/auth.test"
SLACK_CONNECTIONS_OPEN_URL = "https://slack.com/api/apps.connections.open"
def validate_bot_token(bot_token: str) -> bool:
headers = {"Authorization": f"Bearer {bot_token}"}
response = requests.post(SLACK_API_URL, headers=headers)
if response.status_code != 200:
raise HTTPException(
status_code=500, detail="Error communicating with Slack API."
)
data = response.json()
if not data.get("ok", False):
raise HTTPException(
status_code=400,
detail=f"Invalid bot token: {data.get('error', 'Unknown error')}",
)
return True
def validate_app_token(app_token: str) -> bool:
headers = {"Authorization": f"Bearer {app_token}"}
response = requests.post(SLACK_CONNECTIONS_OPEN_URL, headers=headers)
if response.status_code != 200:
raise HTTPException(
status_code=500, detail="Error communicating with Slack API."
)
data = response.json()
if not data.get("ok", False):
raise HTTPException(
status_code=400,
detail=f"Invalid app token: {data.get('error', 'Unknown error')}",
)
return True

View File

@@ -108,7 +108,7 @@ def upsert_ingestion_doc(
tenant_id=tenant_id,
)
new_doc, __chunk_count = indexing_pipeline(
indexing_pipeline_result = indexing_pipeline(
document_batch=[document],
index_attempt_metadata=IndexAttemptMetadata(
connector_id=cc_pair.connector_id,
@@ -150,4 +150,7 @@ def upsert_ingestion_doc(
),
)
return IngestionResult(document_id=document.id, already_existed=not bool(new_doc))
return IngestionResult(
document_id=document.id,
already_existed=indexing_pipeline_result.new_docs > 0,
)

View File

@@ -18,7 +18,7 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.persona import get_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.persona import upsert_prompt
from onyx.db.prompts import upsert_prompt
from onyx.db.tools import get_tool_by_name
from onyx.utils.logger import setup_logger

View File

@@ -479,7 +479,10 @@ def get_max_document_tokens(
raise HTTPException(status_code=404, detail="Persona not found")
return MaxSelectedDocumentTokens(
max_tokens=compute_max_document_tokens_for_persona(persona),
max_tokens=compute_max_document_tokens_for_persona(
db_session=db_session,
persona=persona,
),
)

View File

@@ -35,7 +35,7 @@ class LongTermLogger:
def _cleanup_old_files(self, category_path: Path) -> None:
try:
files = sorted(
category_path.glob("*.json"),
[f for f in category_path.glob("*.json") if f.is_file()],
key=lambda x: x.stat().st_mtime, # Sort by modification time
reverse=True,
)

View File

@@ -108,6 +108,7 @@ logger = getLogger(__name__)
# class MessageSnapshot(BaseModel):
# id: int
# message: str
# message_type: MessageType
# documents: list[AbridgedSearchDoc]

View File

@@ -133,3 +133,25 @@ class ChatSessionManager:
)
for msg in response.json()["messages"]
]
@staticmethod
def create_chat_message_feedback(
message_id: int,
is_positive: bool,
user_performing_action: DATestUser | None = None,
feedback_text: str | None = None,
predefined_feedback: str | None = None,
) -> None:
response = requests.post(
url=f"{API_SERVER_URL}/chat/create-chat-message-feedback",
json={
"chat_message_id": message_id,
"is_positive": is_positive,
"feedback_text": feedback_text,
"predefined_feedback": predefined_feedback,
},
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()

View File

@@ -1,9 +1,11 @@
from uuid import UUID
from uuid import uuid4
import requests
from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestPersona
@@ -16,6 +18,9 @@ class PersonaManager:
def create(
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
task_prompt: str | None = None,
include_citations: bool = False,
num_chunks: float = 5,
llm_relevance_filter: bool = True,
is_public: bool = True,
@@ -28,32 +33,38 @@ class PersonaManager:
llm_model_version_override: str | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
category_id: int | None = None,
label_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
description = description or f"Description for {name}"
system_prompt = system_prompt or f"System prompt for {name}"
task_prompt = task_prompt or f"Task prompt for {name}"
persona_creation_request = {
"name": name,
"description": description,
"num_chunks": num_chunks,
"llm_relevance_filter": llm_relevance_filter,
"is_public": is_public,
"llm_filter_extraction": llm_filter_extraction,
"recency_bias": recency_bias,
"prompt_ids": prompt_ids or [0],
"document_set_ids": document_set_ids or [],
"tool_ids": tool_ids or [],
"llm_model_provider_override": llm_model_provider_override,
"llm_model_version_override": llm_model_version_override,
"users": users or [],
"groups": groups or [],
}
persona_creation_request = PersonaUpsertRequest(
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
is_public=is_public,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
prompt_ids=prompt_ids or [0],
document_set_ids=document_set_ids or [],
tool_ids=tool_ids or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
users=[UUID(user) for user in (users or [])],
groups=groups or [],
label_ids=label_ids or [],
)
response = requests.post(
f"{API_SERVER_URL}/persona",
json=persona_creation_request,
json=persona_creation_request.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
@@ -77,6 +88,7 @@ class PersonaManager:
llm_model_version_override=llm_model_version_override,
users=users or [],
groups=groups or [],
label_ids=label_ids or [],
)
@staticmethod
@@ -84,6 +96,9 @@ class PersonaManager:
persona: DATestPersona,
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
task_prompt: str | None = None,
include_citations: bool = False,
num_chunks: float | None = None,
llm_relevance_filter: bool | None = None,
is_public: bool | None = None,
@@ -96,32 +111,38 @@ class PersonaManager:
llm_model_version_override: str | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
persona_update_request = {
"name": name or persona.name,
"description": description or persona.description,
"num_chunks": num_chunks or persona.num_chunks,
"llm_relevance_filter": llm_relevance_filter
or persona.llm_relevance_filter,
"is_public": is_public or persona.is_public,
"llm_filter_extraction": llm_filter_extraction
system_prompt = system_prompt or f"System prompt for {persona.name}"
task_prompt = task_prompt or f"Task prompt for {persona.name}"
persona_update_request = PersonaUpsertRequest(
name=name or persona.name,
description=description or persona.description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
num_chunks=num_chunks or persona.num_chunks,
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
is_public=is_public or persona.is_public,
llm_filter_extraction=llm_filter_extraction
or persona.llm_filter_extraction,
"recency_bias": recency_bias or persona.recency_bias,
"prompt_ids": prompt_ids or persona.prompt_ids,
"document_set_ids": document_set_ids or persona.document_set_ids,
"tool_ids": tool_ids or persona.tool_ids,
"llm_model_provider_override": llm_model_provider_override
recency_bias=recency_bias or persona.recency_bias,
prompt_ids=prompt_ids or persona.prompt_ids,
document_set_ids=document_set_ids or persona.document_set_ids,
tool_ids=tool_ids or persona.tool_ids,
llm_model_provider_override=llm_model_provider_override
or persona.llm_model_provider_override,
"llm_model_version_override": llm_model_version_override
llm_model_version_override=llm_model_version_override
or persona.llm_model_version_override,
"users": users or persona.users,
"groups": groups or persona.groups,
}
users=[UUID(user) for user in (users or persona.users)],
groups=groups or persona.groups,
label_ids=label_ids or persona.label_ids,
)
response = requests.patch(
f"{API_SERVER_URL}/persona/{persona.id}",
json=persona_update_request,
json=persona_update_request.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
@@ -137,8 +158,8 @@ class PersonaManager:
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
is_public=updated_persona_data["is_public"],
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
recency_bias=updated_persona_data["recency_bias"],
prompt_ids=updated_persona_data["prompts"],
recency_bias=recency_bias or persona.recency_bias,
prompt_ids=[prompt["id"] for prompt in updated_persona_data["prompts"]],
document_set_ids=updated_persona_data["document_sets"],
tool_ids=updated_persona_data["tools"],
llm_model_provider_override=updated_persona_data[
@@ -149,6 +170,7 @@ class PersonaManager:
],
users=[user["email"] for user in updated_persona_data["users"]],
groups=updated_persona_data["groups"],
label_ids=updated_persona_data["labels"],
)
@staticmethod
@@ -164,12 +186,29 @@ class PersonaManager:
response.raise_for_status()
return [PersonaSnapshot(**persona) for persona in response.json()]
@staticmethod
def get_one(
persona_id: int,
user_performing_action: DATestUser | None = None,
) -> list[PersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/persona/{persona_id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return [PersonaSnapshot(**response.json())]
@staticmethod
def verify(
persona: DATestPersona,
user_performing_action: DATestUser | None = None,
) -> bool:
all_personas = PersonaManager.get_all(user_performing_action)
all_personas = PersonaManager.get_one(
persona_id=persona.id,
user_performing_action=user_performing_action,
)
for fetched_persona in all_personas:
if fetched_persona.id == persona.id:
return (
@@ -199,6 +238,7 @@ class PersonaManager:
and set(user.email for user in fetched_persona.users)
== set(persona.users)
and set(fetched_persona.groups) == set(persona.groups)
and set(fetched_persona.labels) == set(persona.label_ids)
)
return False

View File

@@ -0,0 +1,84 @@
from datetime import datetime
from urllib.parse import urlencode
from uuid import UUID
import requests
from requests.models import CaseInsensitiveDict
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from onyx.configs.constants import QAFeedbackType
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
class QueryHistoryManager:
@staticmethod
def get_query_history_page(
page_num: int = 0,
page_size: int = 10,
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[ChatSessionMinimal]:
query_params: dict[str, str | int] = {
"page_num": page_num,
"page_size": page_size,
}
if feedback_type:
query_params["feedback_type"] = feedback_type.value
if start_time:
query_params["start_time"] = start_time.isoformat()
if end_time:
query_params["end_time"] = end_time.isoformat()
response = requests.get(
url=f"{API_SERVER_URL}/admin/chat-session-history?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
data = response.json()
return PaginatedReturn(
items=[ChatSessionMinimal(**item) for item in data["items"]],
total_items=data["total_items"],
)
@staticmethod
def get_chat_session_admin(
chat_session_id: UUID | str,
user_performing_action: DATestUser | None = None,
) -> ChatSessionSnapshot:
response = requests.get(
url=f"{API_SERVER_URL}/admin/chat-session-history/{chat_session_id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return ChatSessionSnapshot(**response.json())
@staticmethod
def get_query_history_as_csv(
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> tuple[CaseInsensitiveDict[str], str]:
query_params: dict[str, str | int] = {}
if start_time:
query_params["start"] = start_time.isoformat()
if end_time:
query_params["end"] = end_time.isoformat()
response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return response.headers, response.content.decode()

View File

@@ -213,17 +213,16 @@ class UserManager:
is_active_filter: bool | None = None,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[FullUserSnapshot]:
query_params = {
query_params: dict[str, str | list[str] | int] = {
"page_num": page_num,
"page_size": page_size,
"q": search_query if search_query else None,
"roles": [role.value for role in role_filter] if role_filter else None,
"is_active": is_active_filter if is_active_filter is not None else None,
}
# Remove None values
query_params = {
key: value for key, value in query_params.items() if value is not None
}
if search_query:
query_params["q"] = search_query
if role_filter:
query_params["roles"] = [role.value for role in role_filter]
if is_active_filter is not None:
query_params["is_active"] = is_active_filter
response = requests.get(
url=f"{API_SERVER_URL}/manage/users/accepted?{urlencode(query_params, doseq=True)}",

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
@@ -127,14 +128,7 @@ class DATestPersona(BaseModel):
llm_model_version_override: str | None
users: list[str]
groups: list[int]
category_id: int | None = None
#
class DATestChatSession(BaseModel):
id: UUID
persona_id: int
description: str
label_ids: list[int]
class DATestChatMessage(BaseModel):
@@ -144,6 +138,16 @@ class DATestChatMessage(BaseModel):
message: str
class DATestChatSession(BaseModel):
id: UUID
persona_id: int
description: str
class DAQueryHistoryEntry(DATestChatSession):
feedback_type: QAFeedbackType | None
class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None

View File

@@ -0,0 +1,175 @@
"""
This file tests the permissions for creating and editing personas for different user roles:
- Basic users can create personas and edit their own
- Curators can edit personas that belong exclusively to groups they curate
- Admins can edit all personas
"""
import pytest
from requests.exceptions import HTTPError
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.managers.user import DATestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_persona_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# Creating a curator user
curator: DATestUser = UserManager.create(name="curator")
# Creating a basic user
basic_user: DATestUser = UserManager.create(name="basic_user")
# Creating user groups
user_group_1 = UserGroupManager.create(
name="curated_user_group",
user_ids=[curator.id],
cc_pair_ids=[],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1], user_performing_action=admin_user
)
# Setting the user as a curator for the user group
UserGroupManager.set_curator_status(
test_user_group=user_group_1,
user_to_set_as_curator=curator,
user_performing_action=admin_user,
)
# Creating another user group that the user is not a curator of
user_group_2 = UserGroupManager.create(
name="uncurated_user_group",
user_ids=[curator.id],
cc_pair_ids=[],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2], user_performing_action=admin_user
)
"""Test that any user can create a persona"""
# Basic user creates a persona
basic_user_persona = PersonaManager.create(
name="basic_user_persona",
description="A persona created by basic user",
is_public=False,
groups=[],
user_performing_action=basic_user,
)
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
# Curator creates a persona
curator_persona = PersonaManager.create(
name="curator_persona",
description="A persona created by curator",
is_public=False,
groups=[],
user_performing_action=curator,
)
PersonaManager.verify(curator_persona, user_performing_action=curator)
# Admin creates personas for different groups
admin_persona_group_1 = PersonaManager.create(
name="admin_persona_group_1",
description="A persona for group 1",
is_public=False,
groups=[user_group_1.id],
user_performing_action=admin_user,
)
admin_persona_group_2 = PersonaManager.create(
name="admin_persona_group_2",
description="A persona for group 2",
is_public=False,
groups=[user_group_2.id],
user_performing_action=admin_user,
)
admin_persona_both_groups = PersonaManager.create(
name="admin_persona_both_groups",
description="A persona for both groups",
is_public=False,
groups=[user_group_1.id, user_group_2.id],
user_performing_action=admin_user,
)
"""Test that users can edit their own personas"""
# Basic user can edit their own persona
PersonaManager.edit(
persona=basic_user_persona,
description="Updated description by basic user",
user_performing_action=basic_user,
)
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
# Basic user cannot edit other's personas
with pytest.raises(HTTPError):
PersonaManager.edit(
persona=curator_persona,
description="Invalid edit by basic user",
user_performing_action=basic_user,
)
"""Test curator permissions"""
# Curator can edit personas that belong exclusively to groups they curate
PersonaManager.edit(
persona=admin_persona_group_1,
description="Updated by curator",
user_performing_action=curator,
)
PersonaManager.verify(admin_persona_group_1, user_performing_action=curator)
# Curator cannot edit personas in groups they don't curate
with pytest.raises(HTTPError):
PersonaManager.edit(
persona=admin_persona_group_2,
description="Invalid edit by curator",
user_performing_action=curator,
)
# Curator cannot edit personas that belong to multiple groups, even if they curate one
with pytest.raises(HTTPError):
PersonaManager.edit(
persona=admin_persona_both_groups,
description="Invalid edit by curator",
user_performing_action=curator,
)
"""Test admin permissions"""
# Admin can edit any persona
PersonaManager.edit(
persona=basic_user_persona,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(basic_user_persona, user_performing_action=admin_user)
PersonaManager.edit(
persona=curator_persona,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(curator_persona, user_performing_action=admin_user)
PersonaManager.edit(
persona=admin_persona_group_1,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(admin_persona_group_1, user_performing_action=admin_user)
PersonaManager.edit(
persona=admin_persona_group_2,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(admin_persona_group_2, user_performing_action=admin_user)
PersonaManager.edit(
persona=admin_persona_both_groups,
description="Updated by admin",
user_performing_action=admin_user,
)
PersonaManager.verify(admin_persona_both_groups, user_performing_action=admin_user)

View File

@@ -3,16 +3,15 @@ from datetime import timedelta
from datetime import timezone
import pytest
import requests
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import SessionType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.query_history import QueryHistoryManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
@@ -69,66 +68,52 @@ def test_chat_history_endpoints(
) -> None:
admin_user, first_chat_id = setup_chat_session
response = requests.get(
f"{API_SERVER_URL}/admin/chat-session-history",
headers=admin_user.headers,
# Get chat history
history_response = QueryHistoryManager.get_query_history_page(
user_performing_action=admin_user
)
assert response.status_code == 200
history_response = response.json()
# Verify we got back the one chat session we created
assert len(history_response) == 1
assert len(history_response.items) == 1
# Verify the first chat session details
first_session = history_response[0]
assert first_session["user_email"] == admin_user.email
assert first_session["name"] == "Test chat session"
assert first_session["first_user_message"] == "What was the Q1 revenue?"
assert first_session["first_ai_message"] is not None
assert first_session["assistant_id"] == 0
assert first_session["feedback_type"] is None
assert first_session["flow_type"] == SessionType.CHAT.value
assert first_session["conversation_length"] == 4 # 2 User messages + 2 AI responses
first_session = history_response.items[0]
assert first_session.user_email == admin_user.email
assert first_session.name == "Test chat session"
assert first_session.first_user_message == "What was the Q1 revenue?"
assert first_session.first_ai_message is not None
assert first_session.assistant_id == 0
assert first_session.feedback_type is None
assert first_session.flow_type == SessionType.CHAT
assert first_session.conversation_length == 4 # 2 User messages + 2 AI responses
# Test date filtering - should return no results
past_end = datetime.now(tz=timezone.utc) - timedelta(days=1)
past_start = past_end - timedelta(days=1)
response = requests.get(
f"{API_SERVER_URL}/admin/chat-session-history",
params={
"start": past_start.isoformat(),
"end": past_end.isoformat(),
},
headers=admin_user.headers,
history_response = QueryHistoryManager.get_query_history_page(
start_time=past_start,
end_time=past_end,
user_performing_action=admin_user,
)
assert response.status_code == 200
history_response = response.json()
assert len(history_response) == 0
assert len(history_response.items) == 0
# Test get specific chat session endpoint
response = requests.get(
f"{API_SERVER_URL}/admin/chat-session-history/{first_chat_id}",
headers=admin_user.headers,
session_details = QueryHistoryManager.get_chat_session_admin(
chat_session_id=first_chat_id,
user_performing_action=admin_user,
)
assert response.status_code == 200
session_details = response.json()
# Verify the session details
assert session_details["id"] == first_chat_id
assert len(session_details["messages"]) > 0
assert session_details["flow_type"] == SessionType.CHAT.value
assert str(session_details.id) == first_chat_id
assert len(session_details.messages) > 0
assert session_details.flow_type == SessionType.CHAT
# Test filtering by feedback
response = requests.get(
f"{API_SERVER_URL}/admin/chat-session-history",
params={
"feedback_type": QAFeedbackType.LIKE.value,
},
headers=admin_user.headers,
history_response = QueryHistoryManager.get_query_history_page(
feedback_type=QAFeedbackType.LIKE,
user_performing_action=admin_user,
)
assert response.status_code == 200
history_response = response.json()
assert len(history_response) == 0
assert len(history_response.items) == 0
def test_chat_history_csv_export(
@@ -137,16 +122,13 @@ def test_chat_history_csv_export(
admin_user, _ = setup_chat_session
# Test CSV export endpoint with date filtering
response = requests.get(
f"{API_SERVER_URL}/admin/query-history-csv",
headers=admin_user.headers,
headers, csv_content = QueryHistoryManager.get_query_history_as_csv(
user_performing_action=admin_user,
)
assert response.status_code == 200
assert response.headers["Content-Type"] == "text/csv; charset=utf-8"
assert "Content-Disposition" in response.headers
assert headers["Content-Type"] == "text/csv; charset=utf-8"
assert "Content-Disposition" in headers
# Verify CSV content
csv_content = response.content.decode()
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 3 # Header + 2 QA pairs
assert "chat_session_id" in csv_content
@@ -158,15 +140,10 @@ def test_chat_history_csv_export(
# Test CSV export with date filtering - should return no results
past_end = datetime.now(tz=timezone.utc) - timedelta(days=1)
past_start = past_end - timedelta(days=1)
response = requests.get(
f"{API_SERVER_URL}/admin/query-history-csv",
params={
"start": past_start.isoformat(),
"end": past_end.isoformat(),
},
headers=admin_user.headers,
headers, csv_content = QueryHistoryManager.get_query_history_as_csv(
start_time=past_start,
end_time=past_end,
user_performing_action=admin_user,
)
assert response.status_code == 200
csv_content = response.content.decode()
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 1 # Only header, no data rows

View File

@@ -0,0 +1,112 @@
from datetime import datetime
from onyx.configs.constants import QAFeedbackType
from tests.integration.common_utils.managers.query_history import QueryHistoryManager
from tests.integration.common_utils.test_models import DAQueryHistoryEntry
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.tests.query_history.utils import (
setup_chat_sessions_with_different_feedback,
)
def _verify_query_history_pagination(
chat_sessions: list[DAQueryHistoryEntry],
page_size: int = 5,
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_sessions: list[str] = []
for i in range(0, len(chat_sessions), page_size):
paginated_result = QueryHistoryManager.get_query_history_page(
page_num=i // page_size,
page_size=page_size,
feedback_type=feedback_type,
start_time=start_time,
end_time=end_time,
user_performing_action=user_performing_action,
)
# Verify that the total items is equal to the length of the chat sessions list
assert paginated_result.total_items == len(chat_sessions)
# Verify that the number of items in the page is equal to the page size
assert len(paginated_result.items) == min(page_size, len(chat_sessions) - i)
# Add the retrieved chat sessions to the list of retrieved sessions
retrieved_sessions.extend(
[str(session.id) for session in paginated_result.items]
)
# Create a set of all the expected chat session IDs
all_expected_sessions = set(str(session.id) for session in chat_sessions)
# Create a set of all the retrieved chat session IDs
all_retrieved_sessions = set(retrieved_sessions)
# Verify that the set of retrieved sessions is equal to the set of expected sessions
assert all_expected_sessions == all_retrieved_sessions
def test_query_history_pagination(reset: None) -> None:
(
admin_user,
chat_sessions_by_feedback_type,
) = setup_chat_sessions_with_different_feedback()
all_chat_sessions = []
for _, chat_sessions in chat_sessions_by_feedback_type.items():
all_chat_sessions.extend(chat_sessions)
# Verify basic pagination with different page sizes
print("Verifying basic pagination with page size 5")
_verify_query_history_pagination(
chat_sessions=all_chat_sessions,
page_size=5,
user_performing_action=admin_user,
)
print("Verifying basic pagination with page size 10")
_verify_query_history_pagination(
chat_sessions=all_chat_sessions,
page_size=10,
user_performing_action=admin_user,
)
print("Verifying pagination with feedback type LIKE")
liked_sessions = chat_sessions_by_feedback_type[QAFeedbackType.LIKE]
_verify_query_history_pagination(
chat_sessions=liked_sessions,
feedback_type=QAFeedbackType.LIKE,
user_performing_action=admin_user,
)
print("Verifying pagination with feedback type DISLIKE")
disliked_sessions = chat_sessions_by_feedback_type[QAFeedbackType.DISLIKE]
_verify_query_history_pagination(
chat_sessions=disliked_sessions,
feedback_type=QAFeedbackType.DISLIKE,
user_performing_action=admin_user,
)
print("Verifying pagination with feedback type MIXED")
mixed_sessions = chat_sessions_by_feedback_type[QAFeedbackType.MIXED]
_verify_query_history_pagination(
chat_sessions=mixed_sessions,
feedback_type=QAFeedbackType.MIXED,
user_performing_action=admin_user,
)
# Test with a small page size to verify handling of partial pages
print("Verifying pagination with page size 3")
_verify_query_history_pagination(
chat_sessions=all_chat_sessions,
page_size=3,
user_performing_action=admin_user,
)
# Test with a page size larger than the total number of items
print("Verifying pagination with page size 50")
_verify_query_history_pagination(
chat_sessions=all_chat_sessions,
page_size=50,
user_performing_action=admin_user,
)

View File

@@ -0,0 +1,131 @@
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from onyx.configs.constants import QAFeedbackType
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DAQueryHistoryEntry
from tests.integration.common_utils.test_models import DATestUser
def _create_chat_session_with_feedback(
    admin_user: DATestUser,
    i: int,
    feedback_type: QAFeedbackType | None,
) -> tuple[QAFeedbackType | None, DAQueryHistoryEntry]:
    """Create a two-message chat session and attach feedback to it.

    Feedback placement encodes the requested type:
      - DISLIKE / MIXED: negative feedback on the first response
      - LIKE / MIXED: positive feedback on the second response
      - None: no feedback at all

    Returns the feedback type together with the created history entry so
    callers can bucket sessions by feedback type.
    """
    print(f"Creating chat session {i} with feedback type {feedback_type}")
    session = ChatSessionManager.create(
        persona_id=0,
        description=f"Test chat session {i}",
        user_performing_action=admin_user,
    )
    entry = DAQueryHistoryEntry(
        id=session.id,
        persona_id=0,
        description=f"Test chat session {i}",
        feedback_type=feedback_type,
    )

    # First exchange.
    ChatSessionManager.send_message(
        chat_session_id=session.id,
        message=f"Question {i}?",
        user_performing_action=admin_user,
    )
    history = ChatSessionManager.get_chat_history(
        chat_session=session,
        user_performing_action=admin_user,
    )
    if feedback_type in (QAFeedbackType.MIXED, QAFeedbackType.DISLIKE):
        ChatSessionManager.create_chat_message_feedback(
            message_id=history[-1].id,
            is_positive=False,
            user_performing_action=admin_user,
        )

    # Second exchange, chained off the latest message.
    ChatSessionManager.send_message(
        chat_session_id=session.id,
        message=f"Follow up {i}?",
        user_performing_action=admin_user,
        parent_message_id=history[-1].id,
    )
    # Refresh history so we can reference the second response's id.
    history = ChatSessionManager.get_chat_history(
        chat_session=session,
        user_performing_action=admin_user,
    )
    if feedback_type in (QAFeedbackType.MIXED, QAFeedbackType.LIKE):
        ChatSessionManager.create_chat_message_feedback(
            message_id=history[-1].id,
            is_positive=True,
            user_performing_action=admin_user,
        )

    return feedback_type, entry
def setup_chat_sessions_with_different_feedback() -> (
    tuple[DATestUser, dict[QAFeedbackType | None, list[DAQueryHistoryEntry]]]
):
    """Seed an admin user, a document, and 40 chat sessions with varied feedback.

    Ten sessions are created for each feedback bucket (MIXED, LIKE, DISLIKE,
    and None), in parallel, and grouped by their feedback type.
    """
    # Create admin user and required resources.
    admin_user: DATestUser = UserManager.create(name="admin_user")
    cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin_user)
    api_key = APIKeyManager.create(user_performing_action=admin_user)
    LLMProviderManager.create(user_performing_action=admin_user)

    # Seed a single document for the connector pair.
    cc_pair.documents = [
        DocumentManager.seed_doc_with_content(
            cc_pair=cc_pair,
            content="The company's revenue in Q1 was $1M",
            api_key=api_key,
        )
    ]

    sessions_by_feedback: dict[
        QAFeedbackType | None, list[DAQueryHistoryEntry]
    ] = {}

    feedback_types: list[QAFeedbackType | None] = [
        QAFeedbackType.MIXED,
        QAFeedbackType.LIKE,
        QAFeedbackType.DISLIKE,
        None,
    ]
    # 10 sessions per bucket x 4 buckets = 40 sessions total.
    sessions_per_type = 10

    # Create the sessions concurrently; each bucket gets a disjoint index
    # range so session descriptions stay unique.
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [
            executor.submit(
                _create_chat_session_with_feedback,
                admin_user,
                bucket_index * sessions_per_type + offset,
                feedback_type,
            )
            for bucket_index, feedback_type in enumerate(feedback_types)
            for offset in range(sessions_per_type)
        ]
        # Group results as they complete (completion order is arbitrary).
        for future in as_completed(futures):
            completed_type, session_entry = future.result()
            sessions_by_feedback.setdefault(completed_type, []).append(
                session_entry
            )

    return admin_user, sessions_by_feedback

View File

@@ -43,12 +43,6 @@ def _verify_user_pagination(
assert all_expected_emails == all_retrieved_emails
def _verify_user_role_and_status(users: list) -> None:
for user in users:
assert UserManager.is_role(user, user.role)
assert UserManager.is_status(user, user.is_active)
def test_user_pagination(reset: None) -> None:
# Create an admin user to perform actions
user_performing_action: DATestUser = UserManager.create(
@@ -108,7 +102,13 @@ def test_user_pagination(reset: None) -> None:
+ inactive_admins
+ searchable_curators
)
_verify_user_role_and_status(all_users)
for user in all_users:
# Verify that the user's role in the db matches
# the role in the user object
assert UserManager.is_role(user, user.role)
# Verify that the user's status in the db matches
# the status in the user object
assert UserManager.is_status(user, user.is_active)
# Verify pagination
_verify_user_pagination(

11
web/jest.config.js Normal file
View File

@@ -0,0 +1,11 @@
// Jest configuration for web unit tests; TS/TSX files are transpiled
// on the fly by ts-jest.
module.exports = {
  preset: "ts-jest",
  testEnvironment: "node",
  moduleNameMapper: {
    // Mirror the "@/..." path alias (from tsconfig) so imports resolve.
    "^@/(.*)$": "<rootDir>/src/$1",
  },
  // Playwright e2e specs live under /tests/e2e/ and are run separately.
  testPathIgnorePatterns: ["/node_modules/", "/tests/e2e/"],
  transform: {
    "^.+\\.tsx?$": "ts-jest",
  },
};

2773
web/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,8 @@
"dev": "next dev --turbopack",
"build": "next build",
"start": "next start",
"lint": "next lint"
"lint": "next lint",
"test": "jest"
},
"dependencies": {
"@dnd-kit/core": "^6.1.0",
@@ -84,10 +85,13 @@
"@chromatic-com/playwright": "^0.10.0",
"@tailwindcss/typography": "^0.5.10",
"@types/chrome": "^0.0.287",
"@types/jest": "^29.5.14",
"chromatic": "^11.18.1",
"eslint": "^8.48.0",
"eslint-config-next": "^14.1.0",
"prettier": "2.8.8"
"jest": "^29.7.0",
"prettier": "2.8.8",
"ts-jest": "^29.2.5"
},
"overrides": {
"react-is": "^19.0.0-rc-69d4b800-20241021"

View File

@@ -6,14 +6,11 @@ import { generateRandomIconShape } from "@/lib/assistantIconUtils";
import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types";
import { Separator } from "@/components/ui/separator";
import { Button } from "@/components/ui/button";
import { Textarea } from "@/components/ui/textarea";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import { ArrayHelpers, FieldArray, Form, Formik, FormikProps } from "formik";
import {
BooleanFormField,
Label,
SelectorFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
@@ -42,14 +39,10 @@ import { FiInfo, FiRefreshCcw, FiUsers } from "react-icons/fi";
import * as Yup from "yup";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import {
Persona,
PersonaLabel,
StarterMessage,
StarterMessageBase,
} from "./interfaces";
import { Persona, PersonaLabel, StarterMessage } from "./interfaces";
import {
createPersonaLabel,
PersonaUpsertParameters,
createPersona,
deletePersonaLabel,
updatePersonaLabel,
@@ -67,30 +60,19 @@ import { useAssistants } from "@/components/context/AssistantsContext";
import { debounce } from "lodash";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import StarterMessagesList from "./StarterMessageList";
import { LabelCard } from "./LabelCard";
import { Switch } from "@/components/ui/switch";
import { generateIdenticon } from "@/components/assistants/AssistantIcon";
import { BackButton } from "@/components/BackButton";
import { Checkbox } from "@/components/ui/checkbox";
import { Input } from "@/components/ui/input";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { AssistantVisibilityPopover } from "@/app/assistants/mine/AssistantVisibilityPopover";
import { MinimalUserSnapshot } from "@/lib/types";
import { useUserGroups } from "@/lib/hooks";
import { useUsers } from "@/lib/hooks";
import { AllUsersResponse } from "@/lib/types";
// import { Badge } from "@/components/ui/Badge";
// import {
// addUsersToAssistantSharedList,
// shareAssistantWithGroups,
// } from "@/lib/assistants/shareAssistant";
import {
SearchMultiSelectDropdown,
Option as DropdownOption,
} from "@/components/Dropdown";
import { Badge } from "@/components/ui/badge";
import { SourceChip } from "@/app/chat/input/ChatInputBar";
import { GroupIcon, TagIcon, UserIcon } from "lucide-react";
import { TagIcon, UserIcon } from "lucide-react";
import { LLMSelector } from "@/components/llm/LLMSelector";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
@@ -268,7 +250,6 @@ export function AssistantEditor({
labels: existingPersona?.labels ?? null,
// EE Only
groups: existingPersona?.groups ?? [],
label_ids: existingPersona?.labels?.map((label) => label.id) ?? [],
selectedUsers:
existingPersona?.users?.filter(
@@ -418,7 +399,6 @@ export function AssistantEditor({
icon_shape: Yup.number(),
uploaded_image: Yup.mixed().nullable(),
// EE Only
groups: Yup.array().of(Yup.number()),
label_ids: Yup.array().of(Yup.number()),
selectedUsers: Yup.array().of(Yup.object()),
selectedGroups: Yup.array().of(Yup.number()),
@@ -494,12 +474,13 @@ export function AssistantEditor({
}));
// don't set groups if marked as public
const groups = values.is_public ? [] : values.groups;
const submissionData = {
const groups = values.is_public ? [] : values.selectedGroups;
const submissionData: PersonaUpsertParameters = {
...values,
existing_prompt_id: existingPrompt?.id ?? null,
is_default_persona: admin!,
starter_messages: starterMessages,
groups: values.is_public ? [] : values.selectedGroups,
groups: groups,
users: values.is_public
? undefined
: [
@@ -514,25 +495,17 @@ export function AssistantEditor({
num_chunks: numChunks,
};
let promptResponse;
let personaResponse;
if (isUpdate) {
[promptResponse, personaResponse] = await updatePersona({
id: existingPersona.id,
existingPromptId: existingPrompt?.id,
...submissionData,
});
personaResponse = await updatePersona(
existingPersona.id,
submissionData
);
} else {
[promptResponse, personaResponse] = await createPersona({
...submissionData,
is_default_persona: admin!,
});
personaResponse = await createPersona(submissionData);
}
let error = null;
if (!promptResponse.ok) {
error = await promptResponse.text();
}
if (!personaResponse) {
error = "Failed to create Assistant - no response received";

View File

@@ -1,7 +1,7 @@
import { FullLLMProvider } from "../configuration/llm/interfaces";
import { Persona, StarterMessage } from "./interfaces";
interface PersonaCreationRequest {
interface PersonaUpsertRequest {
name: string;
description: string;
system_prompt: string;
@@ -10,6 +10,36 @@ interface PersonaCreationRequest {
num_chunks: number | null;
include_citations: boolean;
is_public: boolean;
recency_bias: string;
prompt_ids: number[];
llm_filter_extraction: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
users?: string[];
groups: number[];
tool_ids: number[];
icon_color: string | null;
icon_shape: number | null;
remove_image?: boolean;
uploaded_image_id: string | null;
search_start_date: Date | null;
is_default_persona: boolean;
display_priority: number | null;
label_ids: number[] | null;
}
export interface PersonaUpsertParameters {
name: string;
description: string;
system_prompt: string;
existing_prompt_id: number | null;
task_prompt: string;
document_set_ids: number[];
num_chunks: number | null;
include_citations: boolean;
is_public: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
@@ -20,94 +50,10 @@ interface PersonaCreationRequest {
icon_color: string | null;
icon_shape: number | null;
remove_image?: boolean;
uploaded_image: File | null;
search_start_date: Date | null;
uploaded_image: File | null;
is_default_persona: boolean;
label_ids?: number[];
}
interface PersonaUpdateRequest {
id: number;
existingPromptId: number | undefined;
name: string;
description: string;
system_prompt: string;
task_prompt: string;
document_set_ids: number[];
num_chunks: number | null;
include_citations: boolean;
is_public: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
users?: string[];
groups: number[];
tool_ids: number[];
icon_color: string | null;
icon_shape: number | null;
remove_image: boolean;
uploaded_image: File | null;
search_start_date: Date | null;
label_ids?: number[];
}
function promptNameFromPersonaName(personaName: string) {
return `default-prompt__${personaName}`;
}
function createPrompt({
personaName,
systemPrompt,
taskPrompt,
includeCitations,
}: {
personaName: string;
systemPrompt: string;
taskPrompt: string;
includeCitations: boolean;
}) {
return fetch("/api/prompt", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
name: promptNameFromPersonaName(personaName),
description: `Default prompt for persona ${personaName}`,
system_prompt: systemPrompt,
task_prompt: taskPrompt,
include_citations: includeCitations,
}),
});
}
function updatePrompt({
promptId,
personaName,
systemPrompt,
taskPrompt,
includeCitations,
}: {
promptId: number;
personaName: string;
systemPrompt: string;
taskPrompt: string;
includeCitations: boolean;
}) {
return fetch(`/api/prompt/${promptId}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
name: promptNameFromPersonaName(personaName),
description: `Default prompt for persona ${personaName}`,
system_prompt: systemPrompt,
task_prompt: taskPrompt,
include_citations: includeCitations,
}),
});
label_ids: number[] | null;
}
export const createPersonaLabel = (name: string) => {
@@ -144,56 +90,57 @@ export const updatePersonaLabel = (
});
};
function buildPersonaAPIBody(
creationRequest: PersonaCreationRequest | PersonaUpdateRequest,
promptId: number,
function buildPersonaUpsertRequest(
creationRequest: PersonaUpsertParameters,
uploaded_image_id: string | null
) {
): PersonaUpsertRequest {
const {
name,
description,
system_prompt,
task_prompt,
document_set_ids,
num_chunks,
llm_relevance_filter,
include_citations,
is_public,
groups,
existing_prompt_id,
users,
tool_ids,
icon_color,
icon_shape,
remove_image,
search_start_date,
label_ids,
} = creationRequest;
const is_default_persona =
"is_default_persona" in creationRequest
? creationRequest.is_default_persona
: false;
return {
name,
description,
num_chunks,
llm_relevance_filter,
llm_filter_extraction: false,
is_public,
recency_bias: "base_decay",
prompt_ids: [promptId],
system_prompt,
task_prompt,
document_set_ids,
llm_model_provider_override: creationRequest.llm_model_provider_override,
llm_model_version_override: creationRequest.llm_model_version_override,
starter_messages: creationRequest.starter_messages,
users,
num_chunks,
include_citations,
is_public,
uploaded_image_id,
groups,
users,
tool_ids,
icon_color,
icon_shape,
uploaded_image_id,
remove_image,
search_start_date,
is_default_persona,
label_ids,
is_default_persona: creationRequest.is_default_persona ?? false,
recency_bias: "base_decay",
prompt_ids: existing_prompt_id ? [existing_prompt_id] : [],
llm_filter_extraction: false,
llm_relevance_filter: creationRequest.llm_relevance_filter ?? null,
llm_model_provider_override:
creationRequest.llm_model_provider_override ?? null,
llm_model_version_override:
creationRequest.llm_model_version_override ?? null,
starter_messages: creationRequest.starter_messages ?? null,
display_priority: null,
label_ids: creationRequest.label_ids ?? null,
};
}
@@ -215,92 +162,52 @@ export async function uploadFile(file: File): Promise<string | null> {
}
export async function createPersona(
personaCreationRequest: PersonaCreationRequest
): Promise<[Response, Response | null]> {
// first create prompt
const createPromptResponse = await createPrompt({
personaName: personaCreationRequest.name,
systemPrompt: personaCreationRequest.system_prompt,
taskPrompt: personaCreationRequest.task_prompt,
includeCitations: personaCreationRequest.include_citations,
});
const promptId = createPromptResponse.ok
? (await createPromptResponse.json()).id
: null;
personaUpsertParams: PersonaUpsertParameters
): Promise<Response | null> {
let fileId = null;
if (personaCreationRequest.uploaded_image) {
fileId = await uploadFile(personaCreationRequest.uploaded_image);
if (personaUpsertParams.uploaded_image) {
fileId = await uploadFile(personaUpsertParams.uploaded_image);
if (!fileId) {
return [createPromptResponse, null];
return null;
}
}
const createPersonaResponse =
promptId !== null
? await fetch("/api/persona", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(
buildPersonaAPIBody(personaCreationRequest, promptId, fileId)
),
})
: null;
const createPersonaResponse = await fetch("/api/persona", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(
buildPersonaUpsertRequest(personaUpsertParams, fileId)
),
});
return [createPromptResponse, createPersonaResponse];
return createPersonaResponse;
}
export async function updatePersona(
personaUpdateRequest: PersonaUpdateRequest
): Promise<[Response, Response | null]> {
const { id, existingPromptId } = personaUpdateRequest;
let promptResponse;
let promptId: number | null = null;
if (existingPromptId !== undefined) {
promptResponse = await updatePrompt({
promptId: existingPromptId,
personaName: personaUpdateRequest.name,
systemPrompt: personaUpdateRequest.system_prompt,
taskPrompt: personaUpdateRequest.task_prompt,
includeCitations: personaUpdateRequest.include_citations,
});
promptId = existingPromptId;
} else {
promptResponse = await createPrompt({
personaName: personaUpdateRequest.name,
systemPrompt: personaUpdateRequest.system_prompt,
taskPrompt: personaUpdateRequest.task_prompt,
includeCitations: personaUpdateRequest.include_citations,
});
promptId = promptResponse.ok
? ((await promptResponse.json()).id as number)
: null;
}
id: number,
personaUpsertParams: PersonaUpsertParameters
): Promise<Response | null> {
let fileId = null;
if (personaUpdateRequest.uploaded_image) {
fileId = await uploadFile(personaUpdateRequest.uploaded_image);
if (personaUpsertParams.uploaded_image) {
fileId = await uploadFile(personaUpsertParams.uploaded_image);
if (!fileId) {
return [promptResponse, null];
return null;
}
}
const updatePersonaResponse =
promptResponse.ok && promptId !== null
? await fetch(`/api/persona/${id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(
buildPersonaAPIBody(personaUpdateRequest, promptId, fileId)
),
})
: null;
const updatePersonaResponse = await fetch(`/api/persona/${id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(
buildPersonaUpsertRequest(personaUpsertParams, fileId)
),
});
return [promptResponse, updatePersonaResponse];
return updatePersonaResponse;
}
export function deletePersona(personaId: number) {
@@ -309,25 +216,6 @@ export function deletePersona(personaId: number) {
});
}
export function buildFinalPrompt(
systemPrompt: string,
taskPrompt: string,
retrievalDisabled: boolean
) {
const queryString = Object.entries({
system_prompt: systemPrompt,
task_prompt: taskPrompt,
retrieval_disabled: retrievalDisabled,
})
.map(
([key, value]) =>
`${encodeURIComponent(key)}=${encodeURIComponent(value)}`
)
.join("&");
return fetch(`/api/persona/utils/prompt-explorer?${queryString}`);
}
function smallerNumberFirstComparator(a: number, b: number) {
return a > b ? 1 : -1;
}

View File

@@ -64,7 +64,13 @@ export const SlackTokensForm = ({
router.push(`/admin/bots/${encodeURIComponent(botId)}`);
} else {
const responseJson = await response.json();
const errorMsg = responseJson.detail || responseJson.message;
let errorMsg = responseJson.detail || responseJson.message;
if (errorMsg.includes("Invalid bot token:")) {
errorMsg = "Slack Bot Token is invalid";
} else if (errorMsg.includes("Invalid app token:")) {
errorMsg = "Slack App Token is invalid";
}
setPopup({
message: isUpdate
? `Error updating Slack Bot - ${errorMsg}`

View File

@@ -1438,10 +1438,10 @@ export function ChatPage({
}
}
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
parentMessage =
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
const updateFn = (messages: Message[]) => {
const replacementsMap = regenerationRequest

View File

@@ -9,8 +9,6 @@ import {
} from "react-icons/fi";
import { FeedbackType } from "../types";
import React, {
memo,
ReactNode,
useCallback,
useContext,
useEffect,
@@ -21,7 +19,10 @@ import React, {
import ReactMarkdown from "react-markdown";
import { OnyxDocument, FilteredOnyxDocument } from "@/lib/search/interfaces";
import { SearchSummary } from "./SearchSummary";
import {
markdownToHtml,
getMarkdownForSelection,
} from "@/app/chat/message/codeUtils";
import { SkippedSearch } from "./SkippedSearch";
import remarkGfm from "remark-gfm";
import { CopyButton } from "@/components/CopyButton";
@@ -37,12 +38,10 @@ import { DocumentPreview } from "../files/documents/DocumentPreview";
import { InMessageImage } from "../files/images/InMessageImage";
import { CodeBlock } from "./CodeBlock";
import rehypePrism from "rehype-prism-plus";
import "prismjs/themes/prism-tomorrow.css";
import "./custom-code-styles.css";
import { Persona } from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { LikeFeedback, DislikeFeedback } from "@/components/icons/icons";
import {
CustomTooltip,
@@ -68,7 +67,6 @@ import CsvContent from "../../../components/tools/CSVContent";
import SourceCard, {
SeeMoreBlock,
} from "@/components/chat_search/sources/SourceCard";
import remarkMath from "remark-math";
import rehypeKatex from "rehype-katex";
import "katex/dist/katex.min.css";
@@ -373,15 +371,28 @@ export const AIMessage = ({
);
const renderedMarkdown = useMemo(() => {
if (typeof finalContent !== "string") {
return finalContent;
}
// Create a hidden div with the HTML content for copying
const htmlContent = markdownToHtml(finalContent);
return (
<ReactMarkdown
className="prose max-w-full text-base"
components={markdownComponents}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
>
{finalContent as string}
</ReactMarkdown>
<>
<div
style={{ position: "absolute", left: "-9999px" }}
dangerouslySetInnerHTML={{ __html: htmlContent }}
/>
<ReactMarkdown
className="prose max-w-full text-base"
components={markdownComponents}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
>
{finalContent}
</ReactMarkdown>
</>
);
}, [finalContent, markdownComponents]);
@@ -513,7 +524,68 @@ export const AIMessage = ({
{typeof content === "string" ? (
<div className="overflow-x-visible max-w-content-max">
{renderedMarkdown}
<div
contentEditable="true"
suppressContentEditableWarning
className="focus:outline-none cursor-text select-text"
style={{
MozUserModify: "read-only",
WebkitUserModify: "read-only",
}}
onCopy={(e) => {
e.preventDefault();
const selection = window.getSelection();
const selectedPlainText =
selection?.toString() || "";
if (!selectedPlainText) {
// If no text is selected, copy the full content
const contentStr =
typeof content === "string"
? content
: (
content as JSX.Element
).props?.children?.toString() || "";
const clipboardItem = new ClipboardItem({
"text/html": new Blob(
[
typeof content === "string"
? markdownToHtml(content)
: contentStr,
],
{ type: "text/html" }
),
"text/plain": new Blob([contentStr], {
type: "text/plain",
}),
});
navigator.clipboard.write([clipboardItem]);
return;
}
const contentStr =
typeof content === "string"
? content
: (
content as JSX.Element
).props?.children?.toString() || "";
const markdownText = getMarkdownForSelection(
contentStr,
selectedPlainText
);
const clipboardItem = new ClipboardItem({
"text/html": new Blob(
[markdownToHtml(markdownText)],
{ type: "text/html" }
),
"text/plain": new Blob([selectedPlainText], {
type: "text/plain",
}),
});
navigator.clipboard.write([clipboardItem]);
}}
>
{renderedMarkdown}
</div>
</div>
) : (
content
@@ -559,7 +631,16 @@ export const AIMessage = ({
)}
</div>
<CustomTooltip showTick line content="Copy">
<CopyButton content={content.toString()} />
<CopyButton
content={
typeof content === "string"
? {
html: markdownToHtml(content),
plainText: content,
}
: content.toString()
}
/>
</CustomTooltip>
<CustomTooltip showTick line content="Good response">
<HoverableIcon
@@ -644,7 +725,16 @@ export const AIMessage = ({
)}
</div>
<CustomTooltip showTick line content="Copy">
<CopyButton content={content.toString()} />
<CopyButton
content={
typeof content === "string"
? {
html: markdownToHtml(content),
plainText: content,
}
: content.toString()
}
/>
</CustomTooltip>
<CustomTooltip showTick line content="Good response">

View File

@@ -0,0 +1,176 @@
import { markdownToHtml, parseMarkdownToSegments } from "../codeUtils";
// Unit tests for markdownToHtml: covers bold/italic conversion with both
// delimiter styles, paragraph splitting, and graceful handling of
// malformed, empty, and very long input.
describe("markdownToHtml", () => {
  test("converts bold text with asterisks and underscores", () => {
    expect(markdownToHtml("This is **bold** text")).toBe(
      "<p>This is <strong>bold</strong> text</p>"
    );
    expect(markdownToHtml("This is __bold__ text")).toBe(
      "<p>This is <strong>bold</strong> text</p>"
    );
  });
  test("converts italic text with asterisks and underscores", () => {
    expect(markdownToHtml("This is *italic* text")).toBe(
      "<p>This is <em>italic</em> text</p>"
    );
    expect(markdownToHtml("This is _italic_ text")).toBe(
      "<p>This is <em>italic</em> text</p>"
    );
  });
  test("handles mixed bold and italic", () => {
    expect(markdownToHtml("This is **bold** and *italic* text")).toBe(
      "<p>This is <strong>bold</strong> and <em>italic</em> text</p>"
    );
    expect(markdownToHtml("This is __bold__ and _italic_ text")).toBe(
      "<p>This is <strong>bold</strong> and <em>italic</em> text</p>"
    );
  });
  test("handles text with spaces and special characters", () => {
    expect(markdownToHtml("This is *as delicious and* tasty")).toBe(
      "<p>This is <em>as delicious and</em> tasty</p>"
    );
    expect(markdownToHtml("This is _as delicious and_ tasty")).toBe(
      "<p>This is <em>as delicious and</em> tasty</p>"
    );
  });
  // Double newlines split the output into separate <p> elements.
  test("handles multi-paragraph text with italics", () => {
    const input =
      "Sure! Here is a sentence with one italicized word:\n\nThe cake was _delicious_ and everyone enjoyed it.";
    expect(markdownToHtml(input)).toBe(
      "<p>Sure! Here is a sentence with one italicized word:</p>\n<p>The cake was <em>delicious</em> and everyone enjoyed it.</p>"
    );
  });
  // Unbalanced delimiters must pass through untouched rather than throw.
  test("handles malformed markdown without crashing", () => {
    expect(markdownToHtml("This is *malformed markdown")).toBe(
      "<p>This is *malformed markdown</p>"
    );
    expect(markdownToHtml("This is _also malformed")).toBe(
      "<p>This is _also malformed</p>"
    );
    expect(markdownToHtml("This has **unclosed bold")).toBe(
      "<p>This has **unclosed bold</p>"
    );
    expect(markdownToHtml("This has __unclosed bold")).toBe(
      "<p>This has __unclosed bold</p>"
    );
  });
  // Blank-ish input collapses to an empty string (no empty <p> emitted).
  test("handles empty or null input", () => {
    expect(markdownToHtml("")).toBe("");
    expect(markdownToHtml(" ")).toBe("");
    expect(markdownToHtml("\n")).toBe("");
  });
  test("handles extremely long input without crashing", () => {
    const longText = "This is *italic* ".repeat(1000);
    expect(() => markdownToHtml(longText)).not.toThrow();
  });
});
// Unit tests for parseMarkdownToSegments: each segment records the visible
// text, the raw markdown (including delimiters), and the visible-text
// length, so callers can map plain-text selections back to markdown.
describe("parseMarkdownToSegments", () => {
  test("parses italic text with asterisks", () => {
    const segments = parseMarkdownToSegments("This is *italic* text");
    expect(segments).toEqual([
      { type: "text", text: "This is ", raw: "This is ", length: 8 },
      { type: "italic", text: "italic", raw: "*italic*", length: 6 },
      { type: "text", text: " text", raw: " text", length: 5 },
    ]);
  });
  test("parses italic text with underscores", () => {
    const segments = parseMarkdownToSegments("This is _italic_ text");
    expect(segments).toEqual([
      { type: "text", text: "This is ", raw: "This is ", length: 8 },
      { type: "italic", text: "italic", raw: "_italic_", length: 6 },
      { type: "text", text: " text", raw: " text", length: 5 },
    ]);
  });
  test("parses bold text with asterisks", () => {
    const segments = parseMarkdownToSegments("This is **bold** text");
    expect(segments).toEqual([
      { type: "text", text: "This is ", raw: "This is ", length: 8 },
      { type: "bold", text: "bold", raw: "**bold**", length: 4 },
      { type: "text", text: " text", raw: " text", length: 5 },
    ]);
  });
  test("parses bold text with underscores", () => {
    const segments = parseMarkdownToSegments("This is __bold__ text");
    expect(segments).toEqual([
      { type: "text", text: "This is ", raw: "This is ", length: 8 },
      { type: "bold", text: "bold", raw: "__bold__", length: 4 },
      { type: "text", text: " text", raw: " text", length: 5 },
    ]);
  });
  test("parses text with spaces and special characters in italics", () => {
    const segments = parseMarkdownToSegments(
      "The cake was _delicious_ and everyone enjoyed it."
    );
    expect(segments).toEqual([
      { type: "text", text: "The cake was ", raw: "The cake was ", length: 13 },
      { type: "italic", text: "delicious", raw: "_delicious_", length: 9 },
      {
        type: "text",
        text: " and everyone enjoyed it.",
        raw: " and everyone enjoyed it.",
        length: 25,
      },
    ]);
  });
  // Newlines stay inside plain-text segments; only delimiters split them.
  test("parses multi-paragraph text with italics", () => {
    const segments = parseMarkdownToSegments(
      "Sure! Here is a sentence with one italicized word:\n\nThe cake was _delicious_ and everyone enjoyed it."
    );
    expect(segments).toEqual([
      {
        type: "text",
        text: "Sure! Here is a sentence with one italicized word:\n\nThe cake was ",
        raw: "Sure! Here is a sentence with one italicized word:\n\nThe cake was ",
        length: 65,
      },
      { type: "italic", text: "delicious", raw: "_delicious_", length: 9 },
      {
        type: "text",
        text: " and everyone enjoyed it.",
        raw: " and everyone enjoyed it.",
        length: 25,
      },
    ]);
  });
  // Unbalanced delimiters are treated as plain text; the parser must not loop
  // or throw on them.
  test("handles malformed markdown without crashing", () => {
    expect(() => parseMarkdownToSegments("This is *malformed")).not.toThrow();
    expect(() =>
      parseMarkdownToSegments("This is _also malformed")
    ).not.toThrow();
    expect(() =>
      parseMarkdownToSegments("This has **unclosed bold")
    ).not.toThrow();
    expect(() =>
      parseMarkdownToSegments("This has __unclosed bold")
    ).not.toThrow();
  });
  test("handles empty or null input", () => {
    expect(parseMarkdownToSegments("")).toEqual([]);
    expect(parseMarkdownToSegments(" ")).toEqual([
      { type: "text", text: " ", raw: " ", length: 1 },
    ]);
    expect(parseMarkdownToSegments("\n")).toEqual([
      { type: "text", text: "\n", raw: "\n", length: 1 },
    ]);
  });
  test("handles extremely long input without crashing", () => {
    const longText = "This is *italic* ".repeat(1000);
    expect(() => parseMarkdownToSegments(longText)).not.toThrow();
  });
});

View File

@@ -82,3 +82,252 @@ export const preprocessLaTeX = (content: string) => {
return inlineProcessedContent;
};
/**
 * Convert a small subset of markdown (bold, italic, inline code, fenced
 * code blocks, links, paragraphs) to an HTML string.
 *
 * Code spans are extracted into placeholders FIRST so that:
 *  - fenced ``` blocks are not mangled by the inline-code regex (the
 *    original ran `…` replacement before ```…```, so fences never matched);
 *  - `*`, `_`, `**` inside code are not converted to <em>/<strong>.
 *
 * NOTE(review): input is not HTML-escaped; callers must not feed untrusted
 * content into contexts where raw `<`/`>` would be interpreted.
 *
 * @param content raw markdown text; blank input yields "".
 * @returns HTML with each paragraph wrapped in <p>, joined by "\n".
 */
export const markdownToHtml = (content: string): string => {
  if (!content || !content.trim()) {
    return "";
  }

  // Stash rendered code spans behind NUL-delimited placeholders so the
  // emphasis/link regexes below cannot touch their contents.
  const codeSnippets: string[] = [];
  const stash = (html: string): string => {
    codeSnippets.push(html);
    return `\u0000${codeSnippets.length - 1}\u0000`;
  };

  let processed = content
    // Fenced code blocks first (they take precedence over inline code)
    .replace(
      /```(\w*)\n([\s\S]*?)```/g,
      (_, lang, code) =>
        stash(`<pre><code class="language-${lang}">${code.trim()}</code></pre>`)
    )
    // Inline code
    .replace(/`([^`]+)`/g, (_, code) => stash(`<code>${code}</code>`));

  processed = processed
    .replace(/(\*\*|__)((?:(?!\1).)*?)\1/g, "<strong>$2</strong>") // Bold with ** or __, non-greedy and no nesting
    .replace(/(\*|_)([^*_\n]+?)\1(?!\*|_)/g, "<em>$2</em>") // Italic with * or _
    .replace(/\[([^\]]+)\]\(([^)]+)\)/g, '<a href="$2">$1</a>'); // Links

  // Restore the stashed code spans.
  processed = processed.replace(
    /\u0000(\d+)\u0000/g,
    (_, index) => codeSnippets[Number(index)]
  );

  // Split on blank lines into paragraphs, dropping empty ones.
  return processed
    .split(/\n\n+/)
    .map((para) => para.trim())
    .filter((para) => para.length > 0)
    .map((para) => `<p>${para}</p>`)
    .join("\n");
};
// One parsed slice of a markdown string. `raw` keeps the original syntax
// (including delimiters) so positions in the source can be reconstructed,
// while `text`/`length` describe what the reader actually sees.
interface MarkdownSegment {
  type: "text" | "link" | "code" | "bold" | "italic" | "codeblock";
  text: string; // The visible/plain text
  raw: string; // The raw markdown including syntax
  length: number; // Length of the visible text
}
/**
 * Tokenize a markdown string into an ordered list of segments (code blocks,
 * inline code, links, bold, italic, plain text).
 *
 * Constructs are tried in precedence order at the current cursor position;
 * anything that matches nothing is consumed as plain text. Every loop
 * iteration advances the cursor by at least one character (each construct's
 * raw match is non-empty, and the text fallback consumes >= 1 char), so the
 * old `maxIterations` guard is no longer needed.
 *
 * Bug fixed: previously, an unmatched markdown-significant character at the
 * cursor (e.g. a lone "*") produced a zero-length text slice and zero
 * advance; the loop spun until the iteration cap tripped and silently
 * dropped the remainder of the input.
 *
 * @param markdown raw markdown source
 * @returns segments in source order; [] for empty input
 */
export function parseMarkdownToSegments(markdown: string): MarkdownSegment[] {
  if (!markdown) {
    return [];
  }

  const segments: MarkdownSegment[] = [];
  let currentIndex = 0;

  // Append a segment and advance the cursor past its raw markdown.
  const push = (
    type: MarkdownSegment["type"],
    text: string,
    raw: string
  ): void => {
    segments.push({ type, text, raw, length: text.length });
    currentIndex += raw.length;
  };

  while (currentIndex < markdown.length) {
    const rest = markdown.slice(currentIndex);

    // Code blocks first (they take precedence over every other construct).
    const codeBlockMatch = rest.match(/^```(\w*)\n([\s\S]*?)```/);
    if (codeBlockMatch) {
      push("codeblock", codeBlockMatch[2] ?? "", codeBlockMatch[0]);
      continue;
    }

    // Inline code.
    const inlineCodeMatch = rest.match(/^`([^`]+)`/);
    if (inlineCodeMatch) {
      push("code", inlineCodeMatch[1] ?? "", inlineCodeMatch[0]);
      continue;
    }

    // Links: [text](url).
    const linkMatch = rest.match(/^\[([^\]]+)\]\(([^)]+)\)/);
    if (linkMatch) {
      push("link", linkMatch[1] ?? "", linkMatch[0]);
      continue;
    }

    // Bold: ** or __ (checked before italic so "**" is not read as "*").
    const boldMatch = rest.match(/^(\*\*|__)([^*_\n]*?)\1/);
    if (boldMatch) {
      push("bold", boldMatch[2] ?? "", boldMatch[0]);
      continue;
    }

    // Italic: * or _.
    const italicMatch = rest.match(/^(\*|_)([^*_\n]+?)\1(?!\*|_)/);
    if (italicMatch) {
      push("italic", italicMatch[2] ?? "", italicMatch[0]);
      continue;
    }

    // Plain text: scan forward to the next markdown-significant character.
    const nextSpecialChar = rest.search(/[`\[*_]/);
    if (nextSpecialChar === -1) {
      // No more special characters — the remainder is plain text.
      push("text", rest, rest);
      break;
    }
    // If the special character sits at position 0 but matched no construct
    // above, consume it as text so the loop always makes progress.
    const advance = Math.max(nextSpecialChar, 1);
    const text = rest.slice(0, advance);
    push("text", text, text);
  }

  return segments;
}
/**
 * Map a plain-text selection back to its markdown source form.
 *
 * `content` is parsed into segments; the concatenated plain text of those
 * segments is searched for `selectedText` (first occurrence), and every
 * overlapped segment contributes its selected portion re-wrapped in that
 * segment's markdown syntax. If the selection cannot be located in the plain
 * text, `selectedText` is returned unchanged as a best-effort fallback.
 *
 * Notes:
 * - Italic is always re-emitted with `*` and bold with `**`, even if the
 *   source used `_`/`__`; the rendered result is equivalent.
 * - Link URLs and code-fence language tags are preserved from the raw
 *   markdown (the original implementation dropped the language tag).
 *
 * @param content full markdown source
 * @param selectedText plain-text (rendered) selection to locate
 * @returns markdown for the selected range, or `selectedText` if not found
 */
export function getMarkdownForSelection(
  content: string,
  selectedText: string
): string {
  const segments = parseMarkdownToSegments(content);

  // Plain text as the user sees it: the concatenation of segment texts.
  const plainText = segments.map((segment) => segment.text).join("");

  const startIndex = plainText.indexOf(selectedText);
  if (startIndex === -1) {
    return selectedText;
  }
  const endIndex = startIndex + selectedText.length;

  let cursor = 0;
  let result = "";

  for (const segment of segments) {
    const segmentStart = cursor;
    const segmentEnd = segmentStart + segment.length;
    cursor = segmentEnd;

    // Skip segments entirely outside the selection.
    if (segmentEnd <= startIndex || segmentStart >= endIndex) {
      continue;
    }

    // Portion of this segment's visible text covered by the selection.
    const overlapStart = Math.max(0, startIndex - segmentStart);
    const overlapEnd = Math.min(segment.length, endIndex - segmentStart);
    const selectedPortion = segment.text.slice(overlapStart, overlapEnd);

    switch (segment.type) {
      case "bold":
        result += `**${selectedPortion}**`;
        break;
      case "italic":
        result += `*${selectedPortion}*`;
        break;
      case "code":
        result += `\`${selectedPortion}\``;
        break;
      case "link": {
        // Preserve the target URL from the raw markdown, if present.
        const urlMatch = segment.raw.match(/\]\((.*?)\)/);
        const url = urlMatch ? urlMatch[1] : "";
        result += `[${selectedPortion}](${url})`;
        break;
      }
      case "codeblock": {
        // Preserve the fence's language tag (e.g. ```python).
        const langMatch = segment.raw.match(/^```(\w*)/);
        const lang = langMatch ? langMatch[1] : "";
        result += `\`\`\`${lang}\n${selectedPortion}\n\`\`\``;
        break;
      }
      default:
        // Plain text passes through untouched.
        result += selectedPortion;
    }
  }

  return result;
}

View File

@@ -10,7 +10,6 @@ import { cn } from "@/lib/utils";
import { CalendarIcon } from "lucide-react";
import { format } from "date-fns";
import { getXDaysAgo } from "./dateUtils";
import { Separator } from "@/components/ui/separator";
export const THIRTY_DAYS = "30d";
@@ -84,8 +83,16 @@ export const DateRangeSelector = memo(function DateRangeSelector({
defaultMonth={value?.from}
selected={value}
onSelect={(range) => {
if (range?.from && range?.to) {
onValueChange({ from: range.from, to: range.to });
if (range?.from) {
if (range.to) {
// Normal range selection when initialized with a range
onValueChange({ from: range.from, to: range.to });
} else {
// Single date selection when initialized without a range
const to = new Date(range.from);
const from = new Date(to.setDate(to.getDate() - 1));
onValueChange({ from, to });
}
}
}}
numberOfMonths={2}

View File

@@ -65,27 +65,6 @@ export const useOnyxBotAnalytics = (timeRange: DateRangePickerValue) => {
};
};
export const useQueryHistory = ({
selectedFeedbackType,
timeRange,
}: {
selectedFeedbackType: Feedback | null;
timeRange: DateRange;
}) => {
const url = buildApiPath("/api/admin/chat-session-history", {
feedback_type: selectedFeedbackType,
start: convertDateToStartOfDay(timeRange?.from)?.toISOString(),
end: convertDateToEndOfDay(timeRange?.to)?.toISOString(),
});
const swrResponse = useSWR<ChatSessionMinimal[]>(url, errorHandlingFetcher);
return {
...swrResponse,
refreshQueryHistory: () => mutate(url),
};
};
export function getDatesList(startDate: Date): string[] {
const datesList: string[] = [];
const endDate = new Date(); // current date

View File

@@ -1,4 +1,3 @@
import { useQueryHistory, useTimeRange } from "../lib";
import { Separator } from "@/components/ui/separator";
import {
Table,
@@ -20,8 +19,8 @@ import {
import { ThreeDotsLoader } from "@/components/Loading";
import { ChatSessionMinimal } from "../usage/types";
import { timestampToReadableDate } from "@/lib/dateUtils";
import { FiFrown, FiMinus, FiSmile } from "react-icons/fi";
import { useCallback, useState } from "react";
import { FiFrown, FiMinus, FiSmile, FiMeh } from "react-icons/fi";
import { useCallback, useState, useMemo } from "react";
import { Feedback } from "@/lib/types";
import { DateRange, DateRangeSelector } from "../DateRangeSelector";
import { PageSelector } from "@/components/PageSelector";
@@ -29,8 +28,11 @@ import Link from "next/link";
import { FeedbackBadge } from "./FeedbackBadge";
import { DownloadAsCSV } from "./DownloadAsCSV";
import CardSection from "@/components/admin/CardSection";
import usePaginatedFetch from "@/hooks/usePaginatedFetch";
import { ErrorCallout } from "@/components/ErrorCallout";
const NUM_IN_PAGE = 20;
const ITEMS_PER_PAGE = 20;
const PAGES_PER_BATCH = 2;
function QueryHistoryTableRow({
chatSessionMinimal,
@@ -108,6 +110,12 @@ function SelectFeedbackType({
<span>Dislike</span>
</div>
</SelectItem>
<SelectItem value="mixed">
<div className="flex items-center gap-2">
<FiMeh className="h-4 w-4" />
<span>Mixed</span>
</div>
</SelectItem>
</SelectContent>
</Select>
</div>
@@ -116,31 +124,55 @@ function SelectFeedbackType({
}
export function QueryHistoryTable() {
const [selectedFeedbackType, setSelectedFeedbackType] = useState<
Feedback | "all"
>("all");
const [timeRange, setTimeRange] = useTimeRange();
const [dateRange, setDateRange] = useState<DateRange>(undefined);
const [filters, setFilters] = useState<{
feedback_type?: Feedback | "all";
start_time?: string;
end_time?: string;
}>({});
const { data: chatSessionData } = useQueryHistory({
selectedFeedbackType:
selectedFeedbackType === "all" ? null : selectedFeedbackType,
timeRange,
const {
currentPageData: chatSessionData,
isLoading,
error,
currentPage,
totalPages,
goToPage,
refresh,
} = usePaginatedFetch<ChatSessionMinimal>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: "/api/admin/chat-session-history",
filter: filters,
});
const [page, setPage] = useState(1);
const onTimeRangeChange = useCallback((value: DateRange) => {
setDateRange(value);
const onTimeRangeChange = useCallback(
(value: DateRange) => {
if (value) {
setTimeRange((prevTimeRange) => ({
...prevTimeRange,
from: new Date(value.from),
to: new Date(value.to),
}));
}
},
[setTimeRange]
);
if (value?.from && value?.to) {
setFilters((prev) => ({
...prev,
start_time: value.from.toISOString(),
end_time: value.to.toISOString(),
}));
} else {
setFilters((prev) => {
const newFilters = { ...prev };
delete newFilters.start_time;
delete newFilters.end_time;
return newFilters;
});
}
}, []);
if (error) {
return (
<ErrorCallout
errorTitle="Error fetching query history"
errorMsg={error?.message}
/>
);
}
return (
<CardSection className="mt-8">
@@ -148,12 +180,22 @@ export function QueryHistoryTable() {
<div className="flex">
<div className="gap-y-3 flex flex-col">
<SelectFeedbackType
value={selectedFeedbackType || "all"}
onValueChange={setSelectedFeedbackType}
value={filters.feedback_type || "all"}
onValueChange={(value) => {
setFilters((prev) => {
const newFilters = { ...prev };
if (value === "all") {
delete newFilters.feedback_type;
} else {
newFilters.feedback_type = value;
}
return newFilters;
});
}}
/>
<DateRangeSelector
value={timeRange}
value={dateRange}
onValueChange={onTimeRangeChange}
/>
</div>
@@ -172,33 +214,33 @@ export function QueryHistoryTable() {
<TableHead>Date</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{chatSessionData &&
chatSessionData
.slice(NUM_IN_PAGE * (page - 1), NUM_IN_PAGE * page)
.map((chatSessionMinimal) => (
<QueryHistoryTableRow
key={chatSessionMinimal.id}
chatSessionMinimal={chatSessionMinimal}
/>
))}
</TableBody>
{isLoading ? (
<TableBody>
<TableRow>
<TableCell colSpan={6} className="text-center">
<ThreeDotsLoader />
</TableCell>
</TableRow>
</TableBody>
) : (
<TableBody>
{chatSessionData?.map((chatSessionMinimal) => (
<QueryHistoryTableRow
key={chatSessionMinimal.id}
chatSessionMinimal={chatSessionMinimal}
/>
))}
</TableBody>
)}
</Table>
{chatSessionData && (
<div className="mt-3 flex">
<div className="mx-auto">
<PageSelector
totalPages={Math.ceil(chatSessionData.length / NUM_IN_PAGE)}
currentPage={page}
onPageChange={(newPage) => {
setPage(newPage);
window.scrollTo({
top: 0,
left: 0,
behavior: "smooth",
});
}}
totalPages={totalPages}
currentPage={currentPage}
onPageChange={goToPage}
/>
</div>
</div>

View File

@@ -106,9 +106,7 @@ export default function QueryPage(props: { params: Promise<{ id: string }> }) {
<div className="flex flex-col">
{chatSessionSnapshot.messages.map((message) => {
return (
<MessageDisplay key={message.time_created} message={message} />
);
return <MessageDisplay key={message.id} message={message} />;
})}
</div>
</CardSection>

View File

@@ -25,6 +25,7 @@ export interface AbridgedSearchDoc {
}
export interface MessageSnapshot {
id: number;
message: string;
message_type: "user" | "assistant";
documents: AbridgedSearchDoc[];

View File

@@ -1,23 +1,45 @@
import { useState } from "react";
import { FiCheck, FiCopy } from "react-icons/fi";
import { Hoverable, HoverableIcon } from "./Hoverable";
import { HoverableIcon } from "./Hoverable";
import { CheckmarkIcon, CopyMessageIcon } from "./icons/icons";
export function CopyButton({
content,
onClick,
}: {
content?: string;
content?: string | { html: string; plainText: string };
onClick?: () => void;
}) {
const [copyClicked, setCopyClicked] = useState(false);
const copyToClipboard = async (
content: string | { html: string; plainText: string }
) => {
try {
const clipboardItem = new ClipboardItem({
"text/html": new Blob(
[typeof content === "string" ? content : content.html],
{ type: "text/html" }
),
"text/plain": new Blob(
[typeof content === "string" ? content : content.plainText],
{ type: "text/plain" }
),
});
await navigator.clipboard.write([clipboardItem]);
} catch (err) {
// Fallback to basic text copy if HTML copy fails
await navigator.clipboard.writeText(
typeof content === "string" ? content : content.plainText
);
}
};
return (
<HoverableIcon
icon={copyClicked ? <CheckmarkIcon /> : <CopyMessageIcon />}
onClick={() => {
if (content) {
navigator.clipboard.writeText(content.toString());
copyToClipboard(content);
}
onClick && onClick();

View File

@@ -16,17 +16,19 @@ interface TextViewProps {
presentingDocument: OnyxDocument;
onClose: () => void;
}
export default function TextView({
presentingDocument,
onClose,
}: TextViewProps) {
const [zoom, setZoom] = useState(100);
const [fileContent, setFileContent] = useState<string>("");
const [fileUrl, setFileUrl] = useState<string>("");
const [fileName, setFileName] = useState<string>("");
const [fileContent, setFileContent] = useState("");
const [fileUrl, setFileUrl] = useState("");
const [fileName, setFileName] = useState("");
const [isLoading, setIsLoading] = useState(true);
const [fileType, setFileType] = useState<string>("application/octet-stream");
const [fileType, setFileType] = useState("application/octet-stream");
// Detect if a given MIME type is one of the recognized markdown formats
const isMarkdownFormat = (mimeType: string): boolean => {
const markdownFormats = [
"text/markdown",
@@ -38,6 +40,7 @@ export default function TextView({
return markdownFormats.some((format) => mimeType.startsWith(format));
};
// Detect if a given MIME type can be rendered in an <iframe>
const isSupportedIframeFormat = (mimeType: string): boolean => {
const supportedFormats = [
"application/pdf",
@@ -52,6 +55,7 @@ export default function TextView({
const fetchFile = useCallback(async () => {
setIsLoading(true);
const fileId = presentingDocument.document_id.split("__")[1];
try {
const response = await fetch(
`/api/chat/file/${encodeURIComponent(fileId)}`,
@@ -62,18 +66,33 @@ export default function TextView({
const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
setFileUrl(url);
setFileName(presentingDocument.semantic_identifier || "document");
const contentType =
const originalFileName =
presentingDocument.semantic_identifier || "document";
setFileName(originalFileName);
let contentType =
response.headers.get("Content-Type") || "application/octet-stream";
// If it's octet-stream but file name suggests a markdown extension, override and attempt to read as markdown
if (
contentType === "application/octet-stream" &&
(originalFileName.toLowerCase().endsWith(".md") ||
originalFileName.toLowerCase().endsWith(".markdown"))
) {
contentType = "text/markdown";
}
setFileType(contentType);
if (isMarkdownFormat(blob.type)) {
// If the final content type looks like markdown, read its text
if (isMarkdownFormat(contentType)) {
const text = await blob.text();
setFileContent(text);
}
} catch (error) {
console.error("Error fetching file:", error);
} finally {
// Keep the slight delay for a smoother loading experience
setTimeout(() => {
setIsLoading(false);
}, 1000);
@@ -97,11 +116,8 @@ export default function TextView({
const handleZoomOut = () => setZoom((prev) => Math.max(prev - 25, 100));
return (
<Dialog open={true} onOpenChange={onClose}>
<DialogContent
hideCloseIcon
className="max-w-5xl w-[90vw] flex flex-col justify-between gap-y-0 h-full max-h-[80vh] p-0"
>
<Dialog open onOpenChange={onClose}>
<DialogContent className="max-w-5xl w-[90vw] flex flex-col justify-between gap-y-0 h-full max-h-[80vh] p-0">
<DialogHeader className="px-4 mb-0 pt-2 pb-3 flex flex-row items-center justify-between border-b">
<DialogTitle className="text-lg font-medium truncate">
{fileName}
@@ -120,12 +136,13 @@ export default function TextView({
<Download className="h-4 w-4" />
<span className="sr-only">Download</span>
</Button>
<Button variant="ghost" size="icon" onClick={() => onClose()}>
<Button variant="ghost" size="icon" onClick={onClose}>
<XIcon className="h-4 w-4" />
<span className="sr-only">Close</span>
</Button>
</div>
</DialogHeader>
<div className="mt-0 rounded-b-lg flex-1 overflow-hidden">
<div className="flex items-center justify-center w-full h-full">
{isLoading ? (
@@ -137,7 +154,7 @@ export default function TextView({
</div>
) : (
<div
className={`w-full h-full transform origin-center transition-transform duration-300 ease-in-out`}
className="w-full h-full transform origin-center transition-transform duration-300 ease-in-out"
style={{ transform: `scale(${zoom / 100})` }}
>
{isSupportedIframeFormat(fileType) ? (
@@ -150,7 +167,7 @@ export default function TextView({
<div className="w-full h-full p-6 overflow-y-scroll overflow-x-hidden">
<MinimalMarkdown
content={fileContent}
className="w-full pb-4 h-full text-lg text-wrap break-words"
className="w-full pb-4 h-full text-lg break-words"
/>
</div>
) : (

View File

@@ -5,12 +5,14 @@ import {
AcceptedUserSnapshot,
InvitedUserSnapshot,
} from "@/lib/types";
import { ChatSessionMinimal } from "@/app/ee/admin/performance/usage/types";
import { errorHandlingFetcher } from "@/lib/fetcher";
type PaginatedType =
| IndexAttemptSnapshot
| AcceptedUserSnapshot
| InvitedUserSnapshot;
| InvitedUserSnapshot
| ChatSessionMinimal;
interface PaginatedApiResponse<T extends PaginatedType> {
items: T[];
@@ -22,7 +24,7 @@ interface PaginationConfig {
pagesPerBatch: number;
endpoint: string;
query?: string;
filter?: Record<string, string | boolean | number | string[]>;
filter?: Record<string, string | boolean | number | string[] | Date>;
refreshIntervalInMs?: number;
}

View File

@@ -98,7 +98,7 @@ export type ValidStatuses =
| "in_progress"
| "not_started";
export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE";
export type Feedback = "like" | "dislike";
export type Feedback = "like" | "dislike" | "mixed";
export type AccessType = "public" | "private" | "sync";
export type SessionType = "Chat" | "Search" | "Slack";