1
0
forked from github/onyx

Compare commits

...

2 Commits

Author SHA1 Message Date
trial-danswer
8ee907e8ec Fix 2025-04-20 19:40:16 -07:00
trial-danswer
c343182a55 Paginated responses for Assistants 2025-04-19 19:17:43 -07:00
12 changed files with 777 additions and 164 deletions

View File

@@ -8,6 +8,7 @@ from onyx.server.documents.models import ConnectorSnapshot
from onyx.server.documents.models import CredentialSnapshot
from onyx.server.features.document_set.models import DocumentSet
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.utils import build_persona_snapshot
from onyx.server.manage.models import UserInfo
from onyx.server.manage.models import UserPreferences
@@ -67,7 +68,7 @@ class UserGroup(BaseModel):
DocumentSet.from_model(ds) for ds in user_group_model.document_sets
],
personas=[
PersonaSnapshot.from_model(persona)
build_persona_snapshot(persona)
for persona in user_group_model.personas
if not persona.deleted
],

View File

@@ -7,6 +7,7 @@ from sqlalchemy import delete
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
@@ -35,9 +36,11 @@ from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.notification import create_notification
from onyx.server.documents.models import PaginatedReturn
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.features.persona.utils import build_persona_snapshot
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -268,7 +271,7 @@ def create_update_persona(
logger.exception("Failed to create persona")
raise HTTPException(status_code=400, detail=str(e))
return PersonaSnapshot.from_model(persona)
return build_persona_snapshot(persona)
def update_persona_shared_users(
@@ -316,6 +319,83 @@ def update_persona_public_status(
db_session.commit()
def count_personas_for_user(
user: User | None,
db_session: Session,
get_editable: bool = True,
include_default: bool = True,
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
joinedload_all: bool = False,
persona_ids: list[int] | None = None,
is_public: bool | None = None,
is_pinned: bool | None = None,
is_users: bool | None = None,
name_matches: str | None = None,
is_image_generation_available: bool = False,
has_any_connectors: bool | None = None,
has_image_compatible_model: bool | None = None,
) -> int:
stmt = select(Persona)
stmt = _add_user_filters(stmt, user, get_editable)
if not include_default:
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
if persona_ids:
stmt = stmt.where(Persona.id.in_(persona_ids))
# Filter out personas with unavailable tools
if (
user
and is_image_generation_available is not None
and not is_image_generation_available
):
stmt = stmt.where(
not_(Persona.tools.any(Tool.in_code_tool_id == "ImageGenerationTool"))
)
# Filter by public/private
if is_public is not None:
stmt = stmt.where(Persona.is_public == is_public)
# Filter by search
if name_matches is not None:
stmt = stmt.where(Persona.name.ilike(f"%{name_matches}%"))
# Filter by ownership
if user and is_users is not None:
stmt = stmt.where(Persona.user_id == user.id)
# Filter by pinned
if user and is_pinned is not None:
stmt = stmt.where(Persona.id.in_(user.pinned_assistants or []))
if has_any_connectors is not None and not has_any_connectors:
stmt = stmt.where(
or_(
Persona.num_chunks == 0,
Persona.document_sets.any(DocumentSet.id == Persona.id),
)
)
if has_image_compatible_model is not None and not has_image_compatible_model:
stmt = stmt.where(
not_(Persona.tools.any(Tool.in_code_tool_id == "ImageGenerationTool"))
)
if joinedload_all:
stmt = stmt.options(
selectinload(Persona.prompts),
selectinload(Persona.tools),
selectinload(Persona.document_sets),
selectinload(Persona.groups),
selectinload(Persona.users),
selectinload(Persona.labels),
)
stmt = select(func.count()).select_from(Persona)
return db_session.execute(stmt).scalar_one()
def get_personas_for_user(
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
@@ -325,6 +405,16 @@ def get_personas_for_user(
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
joinedload_all: bool = False,
persona_ids: list[int] | None = None,
page_num: int = 0,
page_size: int = 100,
is_public: bool | None = None,
is_pinned: bool | None = None,
is_users: bool | None = None,
name_matches: str | None = None,
is_image_generation_available: bool = False,
has_any_connectors: bool | None = None,
has_image_compatible_model: bool | None = None,
) -> Sequence[Persona]:
stmt = select(Persona)
stmt = _add_user_filters(stmt, user, get_editable)
@@ -335,6 +425,44 @@ def get_personas_for_user(
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
if persona_ids:
stmt = stmt.where(Persona.id.in_(persona_ids))
# Filter out personas with unavailable tools
if (
user
and is_image_generation_available is not None
and not is_image_generation_available
):
stmt = stmt.where(
not_(Persona.tools.any(Tool.in_code_tool_id == "ImageGenerationTool"))
)
# Filter by public/private
if is_public is not None:
stmt = stmt.where(Persona.is_public == is_public)
# Filter by search
if name_matches is not None:
stmt = stmt.where(Persona.name.ilike(f"%{name_matches}%"))
# Filter by ownership
if user and is_users is not None:
stmt = stmt.where(Persona.user_id == user.id)
# Filter by pinned
if user and is_pinned is not None:
stmt = stmt.where(Persona.id.in_(user.pinned_assistants or []))
if has_any_connectors is not None and not has_any_connectors:
stmt = stmt.where(
or_(
Persona.num_chunks == 0,
Persona.document_sets.any(DocumentSet.id == Persona.id),
)
)
if has_image_compatible_model is not None and not has_image_compatible_model:
stmt = stmt.where(
not_(Persona.tools.any(Tool.in_code_tool_id == "ImageGenerationTool"))
)
stmt = stmt.offset(page_num * page_size).limit(page_size)
if joinedload_all:
stmt = stmt.options(
@@ -350,6 +478,64 @@ def get_personas_for_user(
return results
def get_paginated_personas_for_user(
user: User | None,
db_session: Session,
get_editable: bool = True,
include_deleted: bool = False,
page_num: int = 0,
page_size: int = 100,
persona_ids: list[int] | None = None,
joinedload_all: bool = False,
is_public: bool | None = None,
is_pinned: bool | None = None,
is_users: bool | None = None,
name_matches: str | None = None,
is_image_generation_available: bool | None = None,
has_any_connectors: bool | None = None,
has_image_compatible_model: bool | None = None,
) -> PaginatedReturn[PersonaSnapshot]:
total_items = count_personas_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
include_deleted=include_deleted,
joinedload_all=joinedload_all,
persona_ids=persona_ids,
is_public=is_public,
is_pinned=is_pinned,
is_users=is_users,
name_matches=name_matches,
is_image_generation_available=is_image_generation_available,
has_any_connectors=has_any_connectors,
has_image_compatible_model=has_image_compatible_model,
)
return PaginatedReturn(
items=[
build_persona_snapshot(persona)
for persona in get_personas_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
include_deleted=include_deleted,
joinedload_all=joinedload_all,
page_num=page_num,
page_size=page_size,
persona_ids=persona_ids,
is_public=is_public,
is_pinned=is_pinned,
is_users=is_users,
name_matches=name_matches,
is_image_generation_available=is_image_generation_available,
has_any_connectors=has_any_connectors,
has_image_compatible_model=has_image_compatible_model,
)
],
total_items=total_items,
)
def get_personas(db_session: Session) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from uuid import UUID
@@ -25,6 +26,9 @@ from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.utils import mask_credential_dict
if TYPE_CHECKING:
from onyx.server.features.persona.models import PersonaSnapshot
class DocumentSyncStatus(BaseModel):
doc_id: str
@@ -194,6 +198,7 @@ PaginatedType = TypeVar(
InvitedUserSnapshot,
ChatSessionMinimal,
IndexAttemptErrorPydantic,
"PersonaSnapshot",
)

View File

@@ -26,8 +26,8 @@ from onyx.db.persona import create_assistant_label
from onyx.db.persona import create_update_persona
from onyx.db.persona import delete_persona_label
from onyx.db.persona import get_assistant_labels
from onyx.db.persona import get_paginated_personas_for_user
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import get_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_all_personas_display_priority
@@ -43,6 +43,7 @@ from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.documents.models import PaginatedReturn
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaLabelCreate
@@ -51,6 +52,7 @@ from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.features.persona.models import PromptSnapshot
from onyx.server.features.persona.utils import build_persona_snapshot
from onyx.server.models import DisplayPriorityRequest
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
@@ -147,17 +149,26 @@ def list_personas_admin(
db_session: Session = Depends(get_session),
include_deleted: bool = False,
get_editable: bool = Query(False, description="If true, return editable personas"),
) -> list[PersonaSnapshot]:
return [
PersonaSnapshot.from_model(persona)
for persona in get_personas_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
include_deleted=include_deleted,
joinedload_all=True,
)
]
page_num: int = Query(0, ge=0),
page_size: int = Query(100, ge=1, le=1000),
is_public: bool | None = None,
is_pinned: bool | None = None,
is_users: bool | None = None,
name_matches: str | None = None,
) -> PaginatedReturn[PersonaSnapshot]:
return get_paginated_personas_for_user(
user=user,
db_session=db_session,
get_editable=get_editable,
include_deleted=include_deleted,
joinedload_all=True,
page_num=page_num,
page_size=page_size,
is_public=is_public,
is_pinned=is_pinned,
is_users=is_users,
name_matches=name_matches,
)
@admin_router.patch("/{persona_id}/undelete")
@@ -394,30 +405,35 @@ def list_personas(
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),
) -> list[PersonaSnapshot]:
personas = get_personas_for_user(
page_num: int = Query(0, ge=0),
page_size: int = Query(100, ge=1, le=1000),
is_public: bool | None = None,
is_pinned: bool | None = None,
is_users: bool | None = None,
name_matches: str | None = None,
has_any_connectors: bool | None = None,
has_image_compatible_model: bool | None = None,
) -> PaginatedReturn[PersonaSnapshot]:
return get_paginated_personas_for_user(
user=user,
include_deleted=include_deleted,
db_session=db_session,
include_deleted=include_deleted,
get_editable=False,
joinedload_all=True,
persona_ids=persona_ids,
page_num=page_num,
page_size=page_size,
is_public=is_public,
is_pinned=is_pinned,
is_users=is_users,
name_matches=name_matches,
is_image_generation_available=is_image_generation_available(
db_session=db_session
),
has_any_connectors=has_any_connectors,
has_image_compatible_model=has_image_compatible_model,
)
if persona_ids:
personas = [p for p in personas if p.id in persona_ids]
# Filter out personas with unavailable tools
personas = [
p
for p in personas
if not (
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
and not is_image_generation_available(db_session=db_session)
)
]
return [PersonaSnapshot.from_model(p) for p in personas]
@basic_router.get("/{persona_id}")
def get_persona(
@@ -425,7 +441,7 @@ def get_persona(
user: User | None = Depends(current_limited_user),
db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
return PersonaSnapshot.from_model(
return build_persona_snapshot(
get_persona_by_id(
persona_id=persona_id,
user=user,

View File

@@ -5,7 +5,6 @@ from pydantic import BaseModel
from pydantic import Field
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.models import Persona
from onyx.db.models import PersonaLabel
from onyx.db.models import Prompt
from onyx.db.models import StarterMessage
@@ -114,55 +113,6 @@ class PersonaSnapshot(BaseModel):
search_start_date: datetime | None = None
labels: list["PersonaLabelSnapshot"] = []
@classmethod
def from_model(
cls, persona: Persona, allow_deleted: bool = False
) -> "PersonaSnapshot":
if persona.deleted:
error_msg = f"Persona with ID {persona.id} has been deleted"
if not allow_deleted:
raise ValueError(error_msg)
else:
logger.warning(error_msg)
return PersonaSnapshot(
id=persona.id,
name=persona.name,
owner=(
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
if persona.user
else None
),
is_visible=persona.is_visible,
is_public=persona.is_public,
display_priority=persona.display_priority,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
builtin_persona=persona.builtin_persona,
is_default_persona=persona.is_default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
users=[
MinimalUserSnapshot(id=user.id, email=user.email)
for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
icon_color=persona.icon_color,
icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date,
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
)
class PromptTemplateResponse(BaseModel):
final_prompt_template: str

View File

@@ -0,0 +1,59 @@
from onyx.db.models import Persona
from onyx.server.features.document_set.models import DocumentSet
from onyx.server.features.persona.models import PersonaLabelSnapshot
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PromptSnapshot
from onyx.server.features.tool.models import ToolSnapshot
from onyx.server.models import MinimalUserSnapshot
from onyx.utils.logger import setup_logger
logger = setup_logger()
def build_persona_snapshot(
persona: Persona, allow_deleted: bool = False
) -> PersonaSnapshot:
if persona.deleted:
error_msg = f"Persona with ID {persona.id} has been deleted"
if not allow_deleted:
raise ValueError(error_msg)
else:
logger.warning(error_msg)
return PersonaSnapshot(
id=persona.id,
name=persona.name,
owner=(
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
if persona.user
else None
),
is_visible=persona.is_visible,
is_public=persona.is_public,
display_priority=persona.display_priority,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
builtin_persona=persona.builtin_persona,
is_default_persona=persona.is_default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
users=[
MinimalUserSnapshot(id=user.id, email=user.email) for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
icon_color=persona.icon_color,
icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date,
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
)

View File

@@ -20,6 +20,7 @@ from onyx.db.models import SlackChannelConfig as SlackChannelConfigModel
from onyx.db.models import User
from onyx.onyxbot.slack.config import VALID_SLACK_FILTERS
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.utils import build_persona_snapshot
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
@@ -245,7 +246,7 @@ class SlackChannelConfig(BaseModel):
id=slack_channel_config_model.id,
slack_bot_id=slack_channel_config_model.slack_bot_id,
persona=(
PersonaSnapshot.from_model(
build_persona_snapshot(
slack_channel_config_model.persona, allow_deleted=True
)
if slack_channel_config_model.persona

View File

@@ -1,13 +1,20 @@
"use client";
import React, { useMemo, useState } from "react";
import React, { useEffect, useMemo, useRef, useState } from "react";
import { useRouter } from "next/navigation";
import AssistantCard from "./AssistantCard";
import { useAssistants } from "@/components/context/AssistantsContext";
import { useUser } from "@/components/user/UserProvider";
import { FilterIcon, XIcon } from "lucide-react";
import { checkUserOwnsAssistant } from "@/lib/assistants/checkOwnership";
import {
useFilteredAssistants,
usePrefetchedAllAssistants,
usePrefetchedPinnedAssistants,
usePrefetchedPrivateAssistants,
usePrefetchedPublicAssistants,
usePrefetchedUsersAssistants,
} from "@/hooks/assistants/usePrefetchedAssistants";
import { Persona } from "@/app/admin/assistants/interfaces";
export const AssistantBadgeSelector = ({
text,
selected,
@@ -59,56 +66,185 @@ const useAssistantFilter = () => {
return { assistantFilters, toggleAssistantFilter, setAssistantFilters };
};
const useDebounce = (value: string, delay: number) => {
const [debouncedValue, setDebouncedValue] = useState(value);
useEffect(() => {
const timeoutId = setTimeout(() => setDebouncedValue(value), delay);
return () => clearTimeout(timeoutId);
}, [value, delay]);
return debouncedValue;
};
interface AssistantModalProps {
hideModal: () => void;
}
export function AssistantModal({ hideModal }: AssistantModalProps) {
const { assistants, pinnedAssistants } = useAssistants();
const { allAssistants, goToAllPage, currentAllPage, totalAllPages } =
usePrefetchedAllAssistants();
const [assistants, setAssistants] = useState<Persona[]>([]);
const [currentPage, setCurrentPage] = useState(0);
const [totalPages, setTotalPages] = useState(0);
const {
pinnedAssistants,
goToPinnedPage,
currentPinnedPage,
totalPinnedPages,
} = usePrefetchedPinnedAssistants();
const {
publicAssistants,
goToPublicPage,
currentPublicPage,
totalPublicPages,
} = usePrefetchedPublicAssistants();
const {
privateAssistants,
goToPrivatePage,
currentPrivatePage,
totalPrivatePages,
} = usePrefetchedPrivateAssistants();
const { usersAssistants, goToUsersPage, currentUsersPage, totalUsersPages } =
usePrefetchedUsersAssistants();
const { assistantFilters, toggleAssistantFilter } = useAssistantFilter();
console.log(assistantFilters);
const router = useRouter();
const { user } = useUser();
const [searchQuery, setSearchQuery] = useState("");
const debouncedSearchQuery = useDebounce(searchQuery, 300);
const [isSearchFocused, setIsSearchFocused] = useState(false);
const containerRef = useRef<HTMLDivElement>(null);
// const [activeAsistantFetchType, setActiveAsistantFetchType] = useState("base");
const memoizedCurrentlyVisibleAssistants = useMemo(() => {
return assistants.filter((assistant) => {
const nameMatches = assistant.name
.toLowerCase()
.includes(searchQuery.toLowerCase());
const labelMatches = assistant.labels?.some((label) =>
label.name.toLowerCase().includes(searchQuery.toLowerCase())
);
const publicFilter =
!assistantFilters[AssistantFilter.Public] || assistant.is_public;
const privateFilter =
!assistantFilters[AssistantFilter.Private] || !assistant.is_public;
const pinnedFilter =
!assistantFilters[AssistantFilter.Pinned] ||
(pinnedAssistants.map((a) => a.id).includes(assistant.id) ?? false);
const activeFilters = Object.entries(assistantFilters)
.filter(([_, value]) => value)
.map(([key]) => key as AssistantFilter);
const mineFilter =
!assistantFilters[AssistantFilter.Mine] ||
checkUserOwnsAssistant(user, assistant);
const {
filteredAssistants,
goToFilteredPage,
currentFilteredPage,
totalFilteredPages,
} = useFilteredAssistants(activeFilters, debouncedSearchQuery);
console.log(debouncedSearchQuery);
return (
(nameMatches || labelMatches) &&
publicFilter &&
privateFilter &&
pinnedFilter &&
mineFilter
);
});
}, [assistants, searchQuery, assistantFilters]);
const fetchAssistants = async (
debouncedSearchQuery: string,
assistantFilters: Record<AssistantFilter, boolean>,
page: number | 0
) => {
const active = Object.entries(assistantFilters).filter(
([_, value]) => value
);
if (active.length === 0 && debouncedSearchQuery.length === 0) {
setAssistants(allAssistants || []);
setCurrentPage(currentAllPage);
setTotalPages(totalAllPages);
if (page > 0) {
goToAllPage(page);
}
}
if (active.length === 1 && debouncedSearchQuery.length === 0) {
if (active[0][0] === AssistantFilter.Pinned) {
setAssistants(pinnedAssistants || []);
setCurrentPage(currentPinnedPage);
setTotalPages(totalPinnedPages);
if (page > 0) {
goToPinnedPage(page);
}
}
if (active[0][0] === AssistantFilter.Public) {
setAssistants(publicAssistants || []);
setCurrentPage(currentPublicPage);
setTotalPages(totalPublicPages);
if (page > 0) {
goToPublicPage(page);
}
}
if (active[0][0] === AssistantFilter.Private) {
setAssistants(privateAssistants || []);
setCurrentPage(currentPrivatePage);
setTotalPages(totalPrivatePages);
if (page > 0) {
goToPrivatePage(page);
}
}
if (active[0][0] === AssistantFilter.Mine) {
setAssistants(usersAssistants || []);
setCurrentPage(currentUsersPage);
setTotalPages(totalUsersPages);
if (page > 0) {
goToUsersPage(page);
}
}
} else {
if (active.length >= 2) {
if (
active.includes([AssistantFilter.Private, true]) &&
active.includes([AssistantFilter.Public, true])
) {
setAssistants([]);
}
}
setAssistants(filteredAssistants || []);
setCurrentPage(currentFilteredPage);
setTotalPages(totalFilteredPages);
if (page > 0) {
goToFilteredPage(page);
}
}
};
const featuredAssistants = [
...memoizedCurrentlyVisibleAssistants.filter(
(assistant) => assistant.is_default_persona
),
];
const allAssistants = memoizedCurrentlyVisibleAssistants.filter(
(assistant) => !assistant.is_default_persona
);
let finalAssistants = filteredAssistants;
if (searchQuery.length > 0) {
finalAssistants = filteredAssistants;
} else if (assistantFilters[AssistantFilter.Mine]) {
finalAssistants = usersAssistants;
} else if (assistantFilters[AssistantFilter.Private]) {
finalAssistants = privateAssistants;
} else if (assistantFilters[AssistantFilter.Public]) {
finalAssistants = publicAssistants;
} else if (assistantFilters[AssistantFilter.Pinned]) {
finalAssistants = pinnedAssistants;
}
useEffect(() => {
fetchAssistants(debouncedSearchQuery, assistantFilters, 0);
}, [debouncedSearchQuery, assistantFilters]);
useEffect(() => {
if (currentPage > 0) {
fetchAssistants(debouncedSearchQuery, assistantFilters, currentPage);
}
}, [currentPage]);
const featuredAssistants =
finalAssistants?.filter((assistant) => assistant.is_default_persona) ?? [];
const completeAssistants =
finalAssistants?.filter((assistant) => !assistant.is_default_persona) ?? [];
useEffect(() => {
const handleScroll = () => {
if (containerRef.current) {
const bottom =
containerRef.current.scrollHeight ===
containerRef.current.scrollTop + containerRef.current.clientHeight;
console.log("currentPage", currentPage);
console.log("totalPages", totalPages);
if (bottom && currentPage < totalPages - 1) {
setCurrentPage(currentPage + 1);
}
}
};
const el = containerRef.current;
el?.addEventListener("scroll", handleScroll);
return () => el?.removeEventListener("scroll", handleScroll);
}, [currentPage, assistants]);
return (
<div
@@ -221,9 +357,11 @@ export function AssistantModal({ hideModal }: AssistantModalProps) {
featuredAssistants.map((assistant, index) => (
<div key={index}>
<AssistantCard
pinned={pinnedAssistants
.map((a) => a.id)
.includes(assistant.id)}
pinned={
user?.preferences?.pinned_assistants?.includes(
assistant.id
) ?? false
}
persona={assistant}
closeModal={hideModal}
/>
@@ -236,14 +374,14 @@ export function AssistantModal({ hideModal }: AssistantModalProps) {
)}
</div>
{allAssistants && allAssistants.length > 0 && (
{completeAssistants && completeAssistants.length > 0 && (
<>
<h2 className="text-2xl font-semibold text-text-800 mt-4 mb-2 px-4 py-2">
All Assistants
</h2>
<div className="w-full mt-2 px-2 pb-2 grid grid-cols-1 md:grid-cols-2 gap-x-6 gap-y-6">
{allAssistants
{completeAssistants
.sort((a, b) => b.id - a.id)
.map((assistant, index) => (
<div key={index}>

View File

@@ -16,6 +16,11 @@ import {
filterAssistants,
} from "@/lib/assistants/utils";
import { useUser } from "../user/UserProvider";
import usePaginatedFetch from "@/hooks/usePaginatedFetch";
import {
usePrefetchedAdminAssistants,
usePrefetchedFilteredAssistants,
} from "@/hooks/assistants/usePrefetchedAssistants";
interface AssistantsContextProps {
assistants: Persona[];
@@ -30,6 +35,14 @@ interface AssistantsContextProps {
allAssistants: Persona[];
pinnedAssistants: Persona[];
setPinnedAssistants: Dispatch<SetStateAction<Persona[]>>;
isLoadingEditablePersonas: boolean;
currentEditablePage: number;
totalEditablePages: number;
goToEditablePage: (page: number) => void;
isLoadingAllPersonas: boolean;
currentAllPage: number;
totalAllPages: number;
goToAllPage: (page: number) => void;
}
const AssistantsContext = createContext<AssistantsContextProps | undefined>(
@@ -53,7 +66,20 @@ export const AssistantsProvider: React.FC<{
const { user, isAdmin, isCurator } = useUser();
const [editablePersonas, setEditablePersonas] = useState<Persona[]>([]);
const [allAssistants, setAllAssistants] = useState<Persona[]>([]);
const [isLoadingEditablePersonas, setIsLoadingEditablePersonas] =
useState<boolean>(false);
const [currentEditablePage, setCurrentEditablePage] = useState<number>(1);
const [totalEditablePages, setTotalEditablePages] = useState<number>(1);
const [goToEditablePage, setGoToEditablePage] = useState<
(page: number) => void
>(() => {});
const [isLoadingAllPersonas, setIsLoadingAllPersonas] =
useState<boolean>(false);
const [currentAllPage, setCurrentAllPage] = useState<number>(1);
const [totalAllPages, setTotalAllPages] = useState<number>(1);
const [goToAllPage, setGoToAllPage] = useState<(page: number) => void>(
() => {}
);
const [pinnedAssistants, setPinnedAssistants] = useState<Persona[]>(() => {
if (user?.preferences.pinned_assistants) {
return user.preferences.pinned_assistants
@@ -95,27 +121,52 @@ export const AssistantsProvider: React.FC<{
checkImageGenerationAvailability();
}, []);
const {
adminAssistants: editableAdminAssistants,
isLoadingAdminAssistants: isLoadingEditableAdminAssistants,
currentAdminPage: currentEditableAdminPage,
totalAdminPages: totalEditableAdminPages,
goToAdminPage: goToEditableAdminPage,
} = usePrefetchedAdminAssistants(true);
const {
adminAssistants: allAdminAssistants,
isLoadingAdminAssistants: isLoadingAllAdminAssistants,
currentAdminPage: currentAllAdminPage,
totalAdminPages: totalAllAdminPages,
goToAdminPage: goToAllAdminPage,
} = usePrefetchedAdminAssistants(false);
const {
filteredAssistants: filteredAssistants,
isLoadingFilteredAssistants: isLoadingFilteredAssistants,
currentFilteredPage: currentFilteredPage,
totalFilteredPages: totalFilteredPages,
goToFilteredPage: goToFilteredPage,
} = usePrefetchedFilteredAssistants(
hasAnyConnectors,
hasImageCompatibleModel
);
const fetchPersonas = async () => {
if (!isAdmin && !isCurator) {
return;
}
try {
const [editableResponse, allResponse] = await Promise.all([
fetch("/api/admin/persona?get_editable=true"),
fetch("/api/admin/persona"),
]);
if (editableResponse.ok) {
const editablePersonas = await editableResponse.json();
setEditablePersonas(editablePersonas);
if (editableAdminAssistants) {
setEditablePersonas(editableAdminAssistants);
setIsLoadingEditablePersonas(isLoadingEditableAdminAssistants);
setCurrentEditablePage(currentEditableAdminPage);
setTotalEditablePages(totalEditableAdminPages);
setGoToEditablePage(goToEditableAdminPage);
}
if (allResponse.ok) {
const allPersonas = await allResponse.json();
setAllAssistants(allPersonas);
} else {
console.error("Error fetching personas:", allResponse);
if (allAdminAssistants) {
setAllAssistants(allAdminAssistants);
setIsLoadingAllPersonas(isLoadingAllAdminAssistants);
setCurrentAllPage(currentAllAdminPage);
setTotalAllPages(totalAllAdminPages);
setGoToAllPage(goToAllAdminPage);
}
} catch (error) {
console.error("Error fetching personas:", error);
@@ -128,22 +179,7 @@ export const AssistantsProvider: React.FC<{
const refreshAssistants = async () => {
try {
const response = await fetch("/api/persona", {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (!response.ok) throw new Error("Failed to fetch assistants");
let assistants: Persona[] = await response.json();
let filteredAssistants = filterAssistants(
assistants,
hasAnyConnectors,
hasImageCompatibleModel
);
setAssistants(filteredAssistants);
setAssistants(filteredAssistants || []);
// Fetch and update allAssistants for admins and curators
await fetchPersonas();
@@ -190,7 +226,15 @@ export const AssistantsProvider: React.FC<{
ownedButHiddenAssistants,
refreshAssistants,
editablePersonas,
isLoadingEditablePersonas,
currentEditablePage,
totalEditablePages,
goToEditablePage,
allAssistants,
isLoadingAllPersonas,
currentAllPage,
totalAllPages,
goToAllPage,
isImageGenerationAvailable,
setPinnedAssistants,
pinnedAssistants,

View File

@@ -0,0 +1,210 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import usePaginatedFetch from "../usePaginatedFetch";
import { AssistantFilter } from "@/app/assistants/mine/AssistantModal";
const ITEMS_PER_PAGE = 100;
const PAGES_PER_BATCH = 1;
const AssistantFilterName = {
Pinned: "is_pinned=true",
Public: "is_public=true",
Private: "is_public=false",
Mine: "is_users=true",
};
export const usePrefetchedPublicAssistants = () => {
const {
currentPageData: publicAssistants,
isLoading: isLoadingPublicAssistants,
currentPage: currentPublicPage,
totalPages: totalPublicPages,
goToPage: goToPublicPage,
} = usePaginatedFetch<Persona>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: "/api/persona?is_public=true",
refreshIntervalInMs: undefined,
});
return {
publicAssistants,
isLoadingPublicAssistants,
currentPublicPage,
totalPublicPages,
goToPublicPage,
};
};
export const usePrefetchedPrivateAssistants = () => {
const {
currentPageData: privateAssistants,
isLoading: isLoadingPrivateAssistants,
currentPage: currentPrivatePage,
totalPages: totalPrivatePages,
goToPage: goToPrivatePage,
} = usePaginatedFetch<Persona>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: "/api/persona?is_public=false",
refreshIntervalInMs: undefined,
});
return {
privateAssistants,
isLoadingPrivateAssistants,
currentPrivatePage,
totalPrivatePages,
goToPrivatePage,
};
};
export const usePrefetchedPinnedAssistants = () => {
const {
currentPageData: pinnedAssistants,
isLoading: isLoadingPinnedAssistants,
currentPage: currentPinnedPage,
totalPages: totalPinnedPages,
goToPage: goToPinnedPage,
} = usePaginatedFetch<Persona>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: "/api/persona?is_pinned=true",
refreshIntervalInMs: undefined,
});
return {
pinnedAssistants,
isLoadingPinnedAssistants,
currentPinnedPage,
totalPinnedPages,
goToPinnedPage,
};
};
export const usePrefetchedUsersAssistants = () => {
const {
currentPageData: usersAssistants,
isLoading: isLoadingUsersAssistants,
currentPage: currentUsersPage,
totalPages: totalUsersPages,
goToPage: goToUsersPage,
} = usePaginatedFetch<Persona>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: "/api/persona?is_users=true",
});
return {
usersAssistants,
isLoadingUsersAssistants,
currentUsersPage,
totalUsersPages,
goToUsersPage,
};
};
export const useFilteredAssistants = (
assistantFilters: AssistantFilter[],
searchQuery: string
) => {
const totalQueryParts = [];
assistantFilters.forEach((filter) => {
totalQueryParts.push(AssistantFilterName[filter]);
});
if (searchQuery.length > 0) {
totalQueryParts.push(`name_matches=${searchQuery}`);
}
console.log(totalQueryParts);
const totalQuery = totalQueryParts.join("&");
const {
currentPageData: filteredAssistants,
isLoading: isLoadingFilteredAssistants,
currentPage: currentFilteredPage,
totalPages: totalFilteredPages,
goToPage: goToFilteredPage,
} = usePaginatedFetch<Persona>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: `/api/persona?${totalQuery}`,
refreshIntervalInMs: undefined,
});
return {
filteredAssistants,
isLoadingFilteredAssistants,
currentFilteredPage,
totalFilteredPages,
goToFilteredPage,
};
};
export const usePrefetchedAdminAssistants = (editable: boolean) => {
const {
currentPageData: adminAssistants,
isLoading: isLoadingAdminAssistants,
currentPage: currentAdminPage,
totalPages: totalAdminPages,
goToPage: goToAdminPage,
} = usePaginatedFetch<Persona>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: `/api/admin/persona${editable ? "?get_editable=true" : ""}`,
refreshIntervalInMs: undefined,
});
return {
adminAssistants,
isLoadingAdminAssistants,
currentAdminPage,
totalAdminPages,
goToAdminPage,
};
};
export const usePrefetchedFilteredAssistants = (
hasAnyConnectors: boolean,
hasImageCompatibleModel: boolean
) => {
const {
currentPageData: filteredAssistants,
isLoading: isLoadingFilteredAssistants,
currentPage: currentFilteredPage,
totalPages: totalFilteredPages,
goToPage: goToFilteredPage,
} = usePaginatedFetch<Persona>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: `/api/persona?has_any_connectors=${hasAnyConnectors}&has_image_compatible_model=${hasImageCompatibleModel}`,
refreshIntervalInMs: undefined,
});
return {
filteredAssistants,
isLoadingFilteredAssistants,
currentFilteredPage,
totalFilteredPages,
goToFilteredPage,
};
};
export const usePrefetchedAllAssistants = () => {
const {
currentPageData: allAssistants,
isLoading: isLoadingAllAssistants,
currentPage: currentAllPage,
totalPages: totalAllPages,
goToPage: goToAllPage,
} = usePaginatedFetch<Persona>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: "/api/persona",
refreshIntervalInMs: undefined,
});
return {
allAssistants,
isLoadingAllAssistants,
currentAllPage,
totalAllPages,
goToAllPage,
};
};

View File

@@ -106,7 +106,7 @@ function usePaginatedFetch<T extends PaginatedType>({
}
}
const url = `${endpoint}?${params.toString()}`;
const url = `${endpoint}${endpoint.includes("?") ? `&` : "?"}${params.toString()}`;
const responseData =
await errorHandlingFetcher<PaginatedApiResponse<T>>(url);
@@ -151,9 +151,12 @@ function usePaginatedFetch<T extends PaginatedType>({
if (currentPath) {
const params = new URLSearchParams(searchParams);
params.set("page", page.toString());
router.replace(`${currentPath}?${params.toString()}`, {
scroll: false,
});
router.replace(
`${currentPath}${currentPath.includes("?") ? `&` : "?"}${params.toString()}`,
{
scroll: false,
}
);
}
},
[currentPath, router, searchParams]

View File

@@ -6,7 +6,7 @@ export type FetchAssistantsResponse = [Persona[], string | null];
export async function fetchAssistantsSS(): Promise<FetchAssistantsResponse> {
const response = await fetchSS("/persona");
if (response.ok) {
return [(await response.json()) as Persona[], null];
return [(await response.json()).items as Persona[], null];
}
return [[], (await response.json()).detail || "Unknown Error"];
}