mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-24 00:52:47 +00:00
Compare commits
11 Commits
bo/query_p
...
v1.6.0-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7868ce2c6 | ||
|
|
07c6acf034 | ||
|
|
bbc2c6d7d0 | ||
|
|
47fa208034 | ||
|
|
40d3facf62 | ||
|
|
7eb6a1d861 | ||
|
|
2dd610a358 | ||
|
|
ede7bcb662 | ||
|
|
26c53b6e66 | ||
|
|
d4164f4ac9 | ||
|
|
6629e354d6 |
@@ -0,0 +1,38 @@
|
||||
"""Adding assistant-specific user preferences
|
||||
|
||||
Revision ID: b329d00a9ea6
|
||||
Revises: bd7c3bf8beba
|
||||
Create Date: 2025-08-26 23:14:44.592985
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b329d00a9ea6"
|
||||
down_revision = "bd7c3bf8beba"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"assistant__user_specific_config",
|
||||
sa.Column("assistant_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("disabled_tool_ids", postgresql.ARRAY(sa.Integer()), nullable=False),
|
||||
sa.ForeignKeyConstraint(["assistant_id"], ["persona.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("assistant_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("assistant__user_specific_config")
|
||||
@@ -144,6 +144,21 @@ def _get_force_search_settings(
|
||||
tools: list[Tool],
|
||||
search_tool_override_kwargs: SearchToolOverrideKwargs | None,
|
||||
) -> ForceUseTool:
|
||||
if new_msg_req.forced_tool_ids:
|
||||
forced_tools = [
|
||||
tool for tool in tools if tool.id in new_msg_req.forced_tool_ids
|
||||
]
|
||||
if not forced_tools:
|
||||
raise ValueError(
|
||||
f"No tools found for forced tool IDs: {new_msg_req.forced_tool_ids}"
|
||||
)
|
||||
return ForceUseTool(
|
||||
force_use=True,
|
||||
tool_name=forced_tools[0].name,
|
||||
args=None,
|
||||
override_kwargs=None,
|
||||
)
|
||||
|
||||
internet_search_available = any(
|
||||
isinstance(tool, InternetSearchTool) for tool in tools
|
||||
)
|
||||
@@ -603,6 +618,7 @@ def stream_chat_message_objects(
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
|
||||
@@ -2701,6 +2701,20 @@ class PersonaLabel(Base):
|
||||
)
|
||||
|
||||
|
||||
class Assistant__UserSpecificConfig(Base):
|
||||
__tablename__ = "assistant__user_specific_config"
|
||||
|
||||
assistant_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
disabled_tool_ids: Mapped[list[int]] = mapped_column(
|
||||
postgresql.ARRAY(Integer), nullable=False
|
||||
)
|
||||
|
||||
|
||||
AllowedAnswerFilters = (
|
||||
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
|
||||
)
|
||||
|
||||
198
backend/onyx/db/user_preferences.py
Normal file
198
backend/onyx/db/user_preferences.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def update_user_role(
|
||||
user: User,
|
||||
new_role: UserRole,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update a user's role in the database."""
|
||||
user.role = new_role
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def deactivate_user(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Deactivate a user by setting is_active to False."""
|
||||
user.is_active = False
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def activate_user(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Activate a user by setting is_active to True."""
|
||||
user.is_active = True
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_latest_access_token_for_user(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> AccessToken | None:
|
||||
"""Get the most recent access token for a user."""
|
||||
try:
|
||||
result = db_session.execute(
|
||||
select(AccessToken)
|
||||
.where(AccessToken.user_id == user_id) # type: ignore
|
||||
.order_by(desc(Column("created_at")))
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching AccessToken: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def update_user_temperature_override_enabled(
|
||||
user_id: UUID,
|
||||
temperature_override_enabled: bool,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's temperature override enabled setting."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(temperature_override_enabled=temperature_override_enabled)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_shortcut_enabled(
|
||||
user_id: UUID,
|
||||
shortcut_enabled: bool,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's shortcut enabled setting."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(shortcut_enabled=shortcut_enabled)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_auto_scroll(
|
||||
user_id: UUID,
|
||||
auto_scroll: bool | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's auto scroll setting."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(auto_scroll=auto_scroll)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_default_model(
|
||||
user_id: UUID,
|
||||
default_model: str | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's default model setting."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(default_model=default_model)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_pinned_assistants(
|
||||
user_id: UUID,
|
||||
pinned_assistants: list[int],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's pinned assistants list."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(pinned_assistants=pinned_assistants)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_assistant_visibility(
|
||||
user_id: UUID,
|
||||
hidden_assistants: list[int] | None,
|
||||
visible_assistants: list[int] | None,
|
||||
chosen_assistants: list[int] | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's assistant visibility settings."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(
|
||||
hidden_assistants=hidden_assistants,
|
||||
visible_assistants=visible_assistants,
|
||||
chosen_assistants=chosen_assistants,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_all_user_assistant_specific_configs(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> Sequence[Assistant__UserSpecificConfig]:
|
||||
"""Get the full user assistant specific config for a specific assistant and user."""
|
||||
return db_session.scalars(
|
||||
select(Assistant__UserSpecificConfig).where(
|
||||
Assistant__UserSpecificConfig.user_id == user_id
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def update_assistant_preferences(
|
||||
assistant_id: int,
|
||||
user_id: UUID,
|
||||
new_assistant_preference: UserSpecificAssistantPreference,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update the disabled tools for a specific assistant for a specific user."""
|
||||
# First check if a config already exists
|
||||
result = db_session.execute(
|
||||
select(Assistant__UserSpecificConfig)
|
||||
.where(Assistant__UserSpecificConfig.assistant_id == assistant_id)
|
||||
.where(Assistant__UserSpecificConfig.user_id == user_id)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
# Update existing config
|
||||
config.disabled_tool_ids = new_assistant_preference.disabled_tool_ids
|
||||
else:
|
||||
# Create new config
|
||||
config = Assistant__UserSpecificConfig(
|
||||
assistant_id=assistant_id,
|
||||
user_id=user_id,
|
||||
disabled_tool_ids=new_assistant_preference.disabled_tool_ids,
|
||||
)
|
||||
db_session.add(config)
|
||||
|
||||
db_session.commit()
|
||||
@@ -44,6 +44,13 @@ class AuthTypeResponse(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
|
||||
|
||||
class UserSpecificAssistantPreference(BaseModel):
|
||||
disabled_tool_ids: list[int]
|
||||
|
||||
|
||||
UserSpecificAssistantPreferences = dict[int, UserSpecificAssistantPreference]
|
||||
|
||||
|
||||
class UserPreferences(BaseModel):
|
||||
chosen_assistants: list[int] | None = None
|
||||
hidden_assistants: list[int] = []
|
||||
@@ -56,6 +63,9 @@ class UserPreferences(BaseModel):
|
||||
auto_scroll: bool | None = None
|
||||
temperature_override_enabled: bool | None = None
|
||||
|
||||
# controls which tools are enabled for the user for a specific assistant
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
|
||||
|
||||
|
||||
class TenantSnapshot(BaseModel):
|
||||
tenant_id: str
|
||||
@@ -94,6 +104,7 @@ class UserInfo(BaseModel):
|
||||
team_name: str | None = None,
|
||||
is_anonymous_user: bool | None = None,
|
||||
tenant_info: TenantInfo | None = None,
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
@@ -113,6 +124,7 @@ class UserInfo(BaseModel):
|
||||
visible_assistants=user.visible_assistants,
|
||||
auto_scroll=user.auto_scroll,
|
||||
temperature_override_enabled=user.temperature_override_enabled,
|
||||
assistant_specific_configs=assistant_specific_configs,
|
||||
)
|
||||
),
|
||||
team_name=team_name,
|
||||
|
||||
@@ -15,10 +15,6 @@ from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.email_utils import send_user_email_invite
|
||||
@@ -46,8 +42,19 @@ from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.auth import get_live_users_count
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import User
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.user_preferences import deactivate_user
|
||||
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
|
||||
from onyx.db.user_preferences import get_latest_access_token_for_user
|
||||
from onyx.db.user_preferences import update_assistant_preferences
|
||||
from onyx.db.user_preferences import update_user_assistant_visibility
|
||||
from onyx.db.user_preferences import update_user_auto_scroll
|
||||
from onyx.db.user_preferences import update_user_default_model
|
||||
from onyx.db.user_preferences import update_user_pinned_assistants
|
||||
from onyx.db.user_preferences import update_user_role
|
||||
from onyx.db.user_preferences import update_user_shortcut_enabled
|
||||
from onyx.db.user_preferences import update_user_temperature_override_enabled
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_all_users
|
||||
from onyx.db.users import get_page_of_filtered_users
|
||||
@@ -66,6 +73,8 @@ from onyx.server.manage.models import UserInfo
|
||||
from onyx.server.manage.models import UserPreferences
|
||||
from onyx.server.manage.models import UserRoleResponse
|
||||
from onyx.server.manage.models import UserRoleUpdateRequest
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreferences
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from onyx.server.models import MinimalUserSnapshot
|
||||
@@ -121,9 +130,7 @@ def set_user_role(
|
||||
"remove_curator_status__no_commit",
|
||||
)(db_session, user_to_update)
|
||||
|
||||
user_to_update.role = user_role_update_request.new_role
|
||||
|
||||
db_session.commit()
|
||||
update_user_role(user_to_update, requested_role, db_session)
|
||||
|
||||
|
||||
class TestUpsertRequest(BaseModel):
|
||||
@@ -390,7 +397,7 @@ def remove_invited_user(
|
||||
|
||||
|
||||
@router.patch("/manage/admin/deactivate-user")
|
||||
def deactivate_user(
|
||||
def deactivate_user_api(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -413,9 +420,7 @@ def deactivate_user(
|
||||
if user_to_deactivate.is_active is False:
|
||||
logger.warning("{} is already deactivated".format(user_to_deactivate.email))
|
||||
|
||||
user_to_deactivate.is_active = False
|
||||
db_session.add(user_to_deactivate)
|
||||
db_session.commit()
|
||||
deactivate_user(user_to_deactivate, db_session)
|
||||
|
||||
|
||||
@router.delete("/manage/admin/delete-user")
|
||||
@@ -456,7 +461,7 @@ async def delete_user(
|
||||
|
||||
|
||||
@router.patch("/manage/admin/activate-user")
|
||||
def activate_user(
|
||||
def activate_user_api(
|
||||
user_email: UserByEmail,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -470,9 +475,7 @@ def activate_user(
|
||||
if user_to_activate.is_active is True:
|
||||
logger.warning("{} is already activated".format(user_to_activate.email))
|
||||
|
||||
user_to_activate.is_active = True
|
||||
db_session.add(user_to_activate)
|
||||
db_session.commit()
|
||||
activate_user(user_to_activate, db_session)
|
||||
|
||||
|
||||
@router.get("/manage/admin/valid-domains")
|
||||
@@ -577,23 +580,12 @@ def get_current_token_creation(
|
||||
) -> datetime | None:
|
||||
if user is None:
|
||||
return None
|
||||
try:
|
||||
result = db_session.execute(
|
||||
select(AccessToken)
|
||||
.where(AccessToken.user_id == user.id) # type: ignore
|
||||
.order_by(desc(Column("created_at")))
|
||||
.limit(1)
|
||||
)
|
||||
access_token = result.scalar_one_or_none()
|
||||
|
||||
if access_token:
|
||||
return access_token.created_at
|
||||
else:
|
||||
logger.error("No AccessToken found for user")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching AccessToken: {e}")
|
||||
access_token = get_latest_access_token_for_user(user.id, db_session)
|
||||
if access_token:
|
||||
return access_token.created_at
|
||||
else:
|
||||
logger.error("No AccessToken found for user")
|
||||
return None
|
||||
|
||||
|
||||
@@ -675,7 +667,7 @@ def verify_user_logged_in(
|
||||
|
||||
|
||||
@router.patch("/temperature-override-enabled")
|
||||
def update_user_temperature_override_enabled(
|
||||
def update_user_temperature_override_enabled_api(
|
||||
temperature_override_enabled: bool,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -692,12 +684,9 @@ def update_user_temperature_override_enabled(
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(temperature_override_enabled=temperature_override_enabled)
|
||||
update_user_temperature_override_enabled(
|
||||
user.id, temperature_override_enabled, db_session
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
class ChosenDefaultModelRequest(BaseModel):
|
||||
@@ -705,7 +694,7 @@ class ChosenDefaultModelRequest(BaseModel):
|
||||
|
||||
|
||||
@router.patch("/shortcut-enabled")
|
||||
def update_user_shortcut_enabled(
|
||||
def update_user_shortcut_enabled_api(
|
||||
shortcut_enabled: bool,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -720,16 +709,11 @@ def update_user_shortcut_enabled(
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(shortcut_enabled=shortcut_enabled)
|
||||
)
|
||||
db_session.commit()
|
||||
update_user_shortcut_enabled(user.id, shortcut_enabled, db_session)
|
||||
|
||||
|
||||
@router.patch("/auto-scroll")
|
||||
def update_user_auto_scroll(
|
||||
def update_user_auto_scroll_api(
|
||||
request: AutoScrollRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -744,16 +728,11 @@ def update_user_auto_scroll(
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(auto_scroll=request.auto_scroll)
|
||||
)
|
||||
db_session.commit()
|
||||
update_user_auto_scroll(user.id, request.auto_scroll, db_session)
|
||||
|
||||
|
||||
@router.patch("/user/default-model")
|
||||
def update_user_default_model(
|
||||
def update_user_default_model_api(
|
||||
request: ChosenDefaultModelRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -768,12 +747,7 @@ def update_user_default_model(
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(default_model=request.default_model)
|
||||
)
|
||||
db_session.commit()
|
||||
update_user_default_model(user.id, request.default_model, db_session)
|
||||
|
||||
|
||||
class ReorderPinnedAssistantsRequest(BaseModel):
|
||||
@@ -781,7 +755,7 @@ class ReorderPinnedAssistantsRequest(BaseModel):
|
||||
|
||||
|
||||
@router.patch("/user/pinned-assistants")
|
||||
def update_user_pinned_assistants(
|
||||
def update_user_pinned_assistants_api(
|
||||
request: ReorderPinnedAssistantsRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -798,12 +772,7 @@ def update_user_pinned_assistants(
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(pinned_assistants=ordered_assistant_ids)
|
||||
)
|
||||
db_session.commit()
|
||||
update_user_pinned_assistants(user.id, ordered_assistant_ids, db_session)
|
||||
|
||||
|
||||
class ChosenAssistantsRequest(BaseModel):
|
||||
@@ -833,7 +802,7 @@ def update_assistant_visibility(
|
||||
|
||||
|
||||
@router.patch("/user/assistant-list/update/{assistant_id}")
|
||||
def update_user_assistant_visibility(
|
||||
def update_user_assistant_visibility_api(
|
||||
assistant_id: int,
|
||||
show: bool,
|
||||
user: User | None = Depends(current_user),
|
||||
@@ -861,13 +830,62 @@ def update_user_assistant_visibility(
|
||||
)
|
||||
if updated_preferences.chosen_assistants is not None:
|
||||
updated_preferences.chosen_assistants.append(assistant_id)
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(
|
||||
hidden_assistants=updated_preferences.hidden_assistants,
|
||||
visible_assistants=updated_preferences.visible_assistants,
|
||||
chosen_assistants=updated_preferences.chosen_assistants,
|
||||
)
|
||||
update_user_assistant_visibility(
|
||||
user.id,
|
||||
updated_preferences.hidden_assistants,
|
||||
updated_preferences.visible_assistants,
|
||||
updated_preferences.chosen_assistants,
|
||||
db_session,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/user/assistant/preferences")
|
||||
def get_user_assistant_preferences(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserSpecificAssistantPreferences | None:
|
||||
"""Fetch all assistant preferences for the user."""
|
||||
if user is None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
return no_auth_user.preferences.assistant_specific_configs
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
assistant_specific_configs = get_all_user_assistant_specific_configs(
|
||||
user.id, db_session
|
||||
)
|
||||
return {
|
||||
config.assistant_id: UserSpecificAssistantPreference(
|
||||
disabled_tool_ids=config.disabled_tool_ids
|
||||
)
|
||||
for config in assistant_specific_configs
|
||||
}
|
||||
|
||||
|
||||
@router.patch("/user/assistant/{assistant_id}/preferences")
|
||||
def update_assistant_preferences_for_user_api(
|
||||
assistant_id: int,
|
||||
new_assistant_preference: UserSpecificAssistantPreference,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if user is None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
if no_auth_user.preferences.assistant_specific_configs is None:
|
||||
no_auth_user.preferences.assistant_specific_configs = {}
|
||||
|
||||
no_auth_user.preferences.assistant_specific_configs[assistant_id] = (
|
||||
new_assistant_preference
|
||||
)
|
||||
set_no_auth_user_preferences(store, no_auth_user.preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
update_assistant_preferences(
|
||||
assistant_id, user.id, new_assistant_preference, db_session
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -144,6 +144,12 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
# List of allowed tool IDs to restrict tool usage. If not provided, all tools available to the persona will be used.
|
||||
allowed_tool_ids: list[int] | None = None
|
||||
|
||||
# List of tool IDs we MUST use.
|
||||
forced_tool_ids: list[int] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
|
||||
if self.search_doc_ids is None and self.retrieval_options is None:
|
||||
|
||||
@@ -183,6 +183,7 @@ def construct_tools(
|
||||
internet_search_tool_config: InternetSearchToolConfig | None = None,
|
||||
image_generation_tool_config: ImageGenerationToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs"""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
@@ -193,6 +194,10 @@ def construct_tools(
|
||||
user_oauth_token = user.oauth_accounts[0].access_token
|
||||
|
||||
for db_tool_model in persona.tools:
|
||||
# If allowed_tool_ids is specified, skip tools not in the allowed list
|
||||
if allowed_tool_ids is not None and db_tool_model.id not in allowed_tool_ids:
|
||||
continue
|
||||
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(
|
||||
db_tool_model.in_code_tool_id, db_session
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
import { MinimalPersonaSnapshot } from "../../admin/assistants/interfaces";
|
||||
|
||||
export function ChatIntro({
|
||||
selectedPersona,
|
||||
}: {
|
||||
selectedPersona: MinimalPersonaSnapshot;
|
||||
}) {
|
||||
return (
|
||||
<div data-testid="chat-intro" className="flex flex-col items-center gap-6">
|
||||
<div className="relative flex flex-col gap-y-4 w-fit mx-auto justify-center">
|
||||
<div className="absolute z-10 items-center flex -left-12 top-1/2 -translate-y-1/2">
|
||||
<AssistantIcon size={36} assistant={selectedPersona} />
|
||||
</div>
|
||||
|
||||
<div className="text-4xl text-text font-normal text-center">
|
||||
{selectedPersona.name}
|
||||
</div>
|
||||
</div>
|
||||
<div className="self-stretch text-center text-text-darker text-xl font-[350] leading-normal">
|
||||
{selectedPersona.description}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -28,8 +28,6 @@ import { OnyxInitializingLoader } from "@/components/OnyxInitializingLoader";
|
||||
import { FeedbackModal } from "./modal/FeedbackModal";
|
||||
import { ShareChatSessionModal } from "./modal/ShareChatSessionModal";
|
||||
import { FiArrowDown } from "react-icons/fi";
|
||||
import { ChatIntro } from "./ChatIntro";
|
||||
import { StarterMessages } from "../../../components/assistants/StarterMessage";
|
||||
import { OnyxDocument, MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import Dropzone from "react-dropzone";
|
||||
@@ -51,6 +49,7 @@ import TextView from "@/components/chat/TextView";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { useSendMessageToParent } from "@/lib/extension/utils";
|
||||
import { SUBMIT_MESSAGE_TYPES } from "@/lib/extension/constants";
|
||||
import { Logo } from "@/components/logo/Logo";
|
||||
|
||||
import { getSourceMetadata } from "@/lib/sources";
|
||||
import { UserSettingsModal } from "./modal/UserSettingsModal";
|
||||
@@ -78,7 +77,6 @@ import {
|
||||
import {
|
||||
useCurrentChatState,
|
||||
useSubmittedMessage,
|
||||
useAgenticGenerating,
|
||||
useLoadingError,
|
||||
useIsReady,
|
||||
useIsFetching,
|
||||
@@ -92,6 +90,8 @@ import {
|
||||
import { AIMessage } from "../message/messageComponents/AIMessage";
|
||||
import { FederatedOAuthModal } from "@/components/chat/FederatedOAuthModal";
|
||||
import { HumanMessage } from "../message/HumanMessage";
|
||||
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
import { StarterMessageDisplay } from "./starterMessages/StarterMessageDisplay";
|
||||
|
||||
export function ChatPage({
|
||||
toggle,
|
||||
@@ -115,7 +115,6 @@ export function ChatPage({
|
||||
llmProviders,
|
||||
folders,
|
||||
shouldShowWelcomeModal,
|
||||
proSearchToggled,
|
||||
refreshChatSessions,
|
||||
} = useChatContext();
|
||||
|
||||
@@ -468,7 +467,6 @@ export function ChatPage({
|
||||
const currentChatState = useCurrentChatState();
|
||||
const chatSessionId = useChatSessionStore((state) => state.currentSessionId);
|
||||
const submittedMessage = useSubmittedMessage();
|
||||
const agenticGenerating = useAgenticGenerating();
|
||||
const loadingError = useLoadingError();
|
||||
const uncaughtError = useUncaughtError();
|
||||
const isReady = useIsReady();
|
||||
@@ -531,8 +529,7 @@ export function ChatPage({
|
||||
onSubmit,
|
||||
});
|
||||
|
||||
const autoScrollEnabled =
|
||||
(user?.preferences?.auto_scroll && !agenticGenerating) ?? false;
|
||||
const autoScrollEnabled = user?.preferences?.auto_scroll ?? false;
|
||||
|
||||
useScrollonStream({
|
||||
chatState: currentChatState,
|
||||
@@ -710,14 +707,6 @@ export function ChatPage({
|
||||
redirect("/auth/login");
|
||||
}
|
||||
|
||||
if (noAssistants)
|
||||
return (
|
||||
<>
|
||||
<HealthCheckBanner />
|
||||
<NoAssistantModal isAdmin={isAdmin} />
|
||||
</>
|
||||
);
|
||||
|
||||
const clearSelectedDocuments = () => {
|
||||
setSelectedDocuments([]);
|
||||
clearSelectedItems();
|
||||
@@ -731,6 +720,39 @@ export function ChatPage({
|
||||
);
|
||||
};
|
||||
|
||||
// Determine whether to show the centered input (no messages yet)
|
||||
const showCenteredInput = useMemo(() => {
|
||||
return (
|
||||
messageHistory.length === 0 &&
|
||||
!isFetchingChatMessages &&
|
||||
!loadingError &&
|
||||
!submittedMessage
|
||||
);
|
||||
}, [
|
||||
messageHistory.length,
|
||||
isFetchingChatMessages,
|
||||
loadingError,
|
||||
submittedMessage,
|
||||
]);
|
||||
|
||||
const inputContainerClasses = useMemo(() => {
|
||||
return `absolute pointer-events-none z-10 w-full transition-transform duration-200 ease-out ${
|
||||
showCenteredInput
|
||||
? "top-1/2 left-0 -translate-y-1/2"
|
||||
: "bottom-0 left-0 translate-y-0"
|
||||
}`;
|
||||
}, [showCenteredInput]);
|
||||
|
||||
// handle error case where no assistants are available
|
||||
if (noAssistants) {
|
||||
return (
|
||||
<>
|
||||
<HealthCheckBanner />
|
||||
<NoAssistantModal isAdmin={isAdmin} />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<HealthCheckBanner />
|
||||
@@ -1060,24 +1082,8 @@ export function ChatPage({
|
||||
{messageHistory.length === 0 &&
|
||||
!isFetchingChatMessages &&
|
||||
!loadingError &&
|
||||
!submittedMessage && (
|
||||
<div className="h-full w-[95%] mx-auto flex flex-col justify-center items-center">
|
||||
<ChatIntro selectedPersona={liveAssistant} />
|
||||
|
||||
<StarterMessages
|
||||
currentPersona={liveAssistant}
|
||||
onSubmit={(messageOverride) =>
|
||||
onSubmit({
|
||||
message: messageOverride,
|
||||
selectedFiles: selectedFiles,
|
||||
selectedFolders: selectedFolders,
|
||||
currentMessageFiles: currentMessageFiles,
|
||||
useAgentSearch: deepResearchEnabled,
|
||||
})
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
!submittedMessage &&
|
||||
null}
|
||||
<div
|
||||
style={{ overflowAnchor: "none" }}
|
||||
key={chatSessionId}
|
||||
@@ -1238,11 +1244,8 @@ export function ChatPage({
|
||||
<div ref={endDivRef} />
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
ref={inputRef}
|
||||
className="absolute pointer-events-none bottom-0 z-10 w-full"
|
||||
>
|
||||
{aboveHorizon && (
|
||||
<div ref={inputRef} className={inputContainerClasses}>
|
||||
{!showCenteredInput && aboveHorizon && (
|
||||
<div className="mx-auto w-fit !pointer-events-none flex sticky justify-center">
|
||||
<button
|
||||
onClick={() => clientScrollToBottom()}
|
||||
@@ -1253,16 +1256,31 @@ export function ChatPage({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="pointer-events-auto w-[95%] mx-auto relative mb-8">
|
||||
<div className="pointer-events-auto w-[95%] mx-auto relative mb-8 text-text-600">
|
||||
{showCenteredInput && (
|
||||
<div className="flex justify-center mb-6 transition-opacity duration-300">
|
||||
{/*
|
||||
TODO: decide which way to go
|
||||
<AssistantIcon
|
||||
assistant={liveAssistant}
|
||||
size="large"
|
||||
/>
|
||||
<div className="ml-4 flex justify-center items-center text-center text-3xl font-bold">
|
||||
{liveAssistant.name}
|
||||
</div> */}
|
||||
<Logo
|
||||
height={48}
|
||||
width={48}
|
||||
className="mx-auto"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<ChatInputBar
|
||||
deepResearchEnabled={deepResearchEnabled}
|
||||
setDeepResearchEnabled={() =>
|
||||
toggleDeepResearch()
|
||||
}
|
||||
toggleDocumentSidebar={toggleDocumentSidebar}
|
||||
availableSources={sources}
|
||||
availableDocumentSets={documentSets}
|
||||
availableTags={tags}
|
||||
filterManager={filterManager}
|
||||
llmManager={llmManager}
|
||||
removeDocs={() => {
|
||||
@@ -1292,10 +1310,33 @@ export function ChatPage({
|
||||
selectedAssistant={
|
||||
selectedAssistant || liveAssistant
|
||||
}
|
||||
setFiles={setCurrentMessageFiles}
|
||||
handleFileUpload={handleMessageSpecificFileUpload}
|
||||
textAreaRef={textAreaRef}
|
||||
/>
|
||||
|
||||
{liveAssistant.starter_messages &&
|
||||
liveAssistant.starter_messages.length > 0 &&
|
||||
messageHistory.length === 0 &&
|
||||
showCenteredInput && (
|
||||
<div className="mt-6">
|
||||
<StarterMessageDisplay
|
||||
starterMessages={
|
||||
liveAssistant.starter_messages
|
||||
}
|
||||
onSelectStarterMessage={(message) => {
|
||||
onSubmit({
|
||||
message: message,
|
||||
selectedFiles: selectedFiles,
|
||||
selectedFolders: selectedFolders,
|
||||
currentMessageFiles:
|
||||
currentMessageFiles,
|
||||
useAgentSearch: deepResearchEnabled,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{enterpriseSettings &&
|
||||
enterpriseSettings.custom_lower_disclaimer_content && (
|
||||
<div className="mobile:hidden mt-4 flex items-center justify-center relative w-[95%] mx-auto">
|
||||
|
||||
333
web/src/app/chat/components/input/ActionManagement.tsx
Normal file
333
web/src/app/chat/components/input/ActionManagement.tsx
Normal file
@@ -0,0 +1,333 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
SlidersVerticalIcon,
|
||||
SearchIcon,
|
||||
DisableIcon,
|
||||
IconProps,
|
||||
MoreActionsIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import React, { useState } from "react";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/ui/popover";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { useAssistantsContext } from "@/components/context/AssistantsContext";
|
||||
import Link from "next/link";
|
||||
import { getIconForAction } from "../../services/actionUtils";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
|
||||
interface ActionItemProps {
|
||||
Icon: (iconProps: IconProps) => JSX.Element;
|
||||
label: string;
|
||||
disabled: boolean;
|
||||
isForced: boolean;
|
||||
onToggle: () => void;
|
||||
onForceToggle: () => void;
|
||||
}
|
||||
|
||||
export function ActionItem({
|
||||
Icon,
|
||||
label,
|
||||
disabled,
|
||||
isForced,
|
||||
onToggle,
|
||||
onForceToggle,
|
||||
}: ActionItemProps) {
|
||||
return (
|
||||
<div
|
||||
className={`
|
||||
group
|
||||
flex
|
||||
items-center
|
||||
justify-between
|
||||
px-2
|
||||
cursor-pointer
|
||||
hover:bg-background-100
|
||||
rounded-lg
|
||||
py-2
|
||||
mx-1
|
||||
${isForced ? "bg-accent-100 hover:bg-accent-200" : ""}
|
||||
`}
|
||||
onClick={() => {
|
||||
// If disabled, un-disable the tool
|
||||
if (onToggle && disabled) {
|
||||
onToggle();
|
||||
}
|
||||
|
||||
onForceToggle();
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className={`flex items-center gap-2 flex-1 ${
|
||||
disabled ? "opacity-50" : ""
|
||||
} ${isForced && "text-blue-500"}`}
|
||||
>
|
||||
<Icon size={16} className="text-text-500" />
|
||||
<span
|
||||
className={`text-sm font-medium select-none ${
|
||||
disabled ? "line-through" : ""
|
||||
}`}
|
||||
>
|
||||
{label}
|
||||
</span>
|
||||
</div>
|
||||
<div
|
||||
className={`
|
||||
flex
|
||||
items-center
|
||||
gap-2
|
||||
transition-opacity
|
||||
duration-200
|
||||
${disabled ? "opacity-100" : "opacity-0 group-hover:opacity-100"}
|
||||
`}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onToggle();
|
||||
}}
|
||||
>
|
||||
<DisableIcon
|
||||
className={`transition-colors cursor-pointer ${
|
||||
disabled
|
||||
? "text-text-900 hover:text-text-500"
|
||||
: "text-text-500 hover:text-text-900"
|
||||
}`}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function ToolItem({
|
||||
tool,
|
||||
isToggled,
|
||||
isForced,
|
||||
onToggle,
|
||||
onForceToggle,
|
||||
}: {
|
||||
tool: ToolSnapshot;
|
||||
isToggled: boolean;
|
||||
isForced: boolean;
|
||||
onToggle: () => void;
|
||||
onForceToggle: () => void;
|
||||
}) {
|
||||
const Icon = getIconForAction(tool);
|
||||
return (
|
||||
<ActionItem
|
||||
Icon={Icon}
|
||||
label={tool.display_name || tool.name}
|
||||
disabled={!isToggled}
|
||||
isForced={isForced}
|
||||
onToggle={onToggle}
|
||||
onForceToggle={onForceToggle}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
interface ActionToggleProps {
|
||||
selectedAssistant: MinimalPersonaSnapshot;
|
||||
}
|
||||
|
||||
export function ActionToggle({ selectedAssistant }: ActionToggleProps) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
|
||||
// Get the assistant preference for this assistant
|
||||
const {
|
||||
assistantPreferences,
|
||||
setSpecificAssistantPreferences,
|
||||
forcedToolIds,
|
||||
setForcedToolIds,
|
||||
} = useAssistantsContext();
|
||||
|
||||
const { isAdmin, isCurator } = useUser();
|
||||
|
||||
const assistantPreference = assistantPreferences?.[selectedAssistant.id];
|
||||
const disabledToolIds = assistantPreference?.disabled_tool_ids || [];
|
||||
const toggleToolForCurrentAssistant = (toolId: number) => {
|
||||
const disabled = disabledToolIds.includes(toolId);
|
||||
setSpecificAssistantPreferences(selectedAssistant.id, {
|
||||
disabled_tool_ids: disabled
|
||||
? disabledToolIds.filter((id) => id !== toolId)
|
||||
: [...disabledToolIds, toolId],
|
||||
});
|
||||
};
|
||||
|
||||
const toggleForcedTool = (toolId: number) => {
|
||||
if (forcedToolIds.includes(toolId)) {
|
||||
setForcedToolIds(forcedToolIds.filter((id) => id !== toolId));
|
||||
} else {
|
||||
setForcedToolIds([...forcedToolIds, toolId]);
|
||||
}
|
||||
};
|
||||
|
||||
// Filter tools based on search term
|
||||
const filteredTools = selectedAssistant.tools.filter((tool) => {
|
||||
if (!searchTerm) return true;
|
||||
const searchLower = searchTerm.toLowerCase();
|
||||
return (
|
||||
tool.display_name?.toLowerCase().includes(searchLower) ||
|
||||
tool.name.toLowerCase().includes(searchLower) ||
|
||||
tool.description?.toLowerCase().includes(searchLower)
|
||||
);
|
||||
});
|
||||
|
||||
// If no tools are available, don't render the component
|
||||
if (selectedAssistant.tools.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Popover
|
||||
open={open}
|
||||
onOpenChange={(newOpen) => {
|
||||
setOpen(newOpen);
|
||||
// Clear search when closing
|
||||
if (!newOpen) {
|
||||
setSearchTerm("");
|
||||
}
|
||||
}}
|
||||
>
|
||||
<PopoverTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className="
|
||||
relative
|
||||
cursor-pointer
|
||||
flex
|
||||
items-center
|
||||
group
|
||||
rounded-lg
|
||||
text-input-text
|
||||
hover:bg-background-chat-hover
|
||||
hover:text-neutral-900
|
||||
dark:hover:text-neutral-50
|
||||
py-1.5
|
||||
px-2
|
||||
flex-none
|
||||
whitespace-nowrap
|
||||
overflow-hidden
|
||||
focus:outline-none
|
||||
"
|
||||
data-testid="action-popover-trigger"
|
||||
title={open ? undefined : "Configure actions"}
|
||||
>
|
||||
<SlidersVerticalIcon size={16} className="my-auto flex-none" />
|
||||
</button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
side="top"
|
||||
align="start"
|
||||
className="
|
||||
w-[244px]
|
||||
max-h-[300px]
|
||||
text-text-600
|
||||
text-sm
|
||||
p-0
|
||||
bg-background
|
||||
border
|
||||
border-border
|
||||
rounded-xl
|
||||
shadow-xl
|
||||
overflow-hidden
|
||||
flex
|
||||
flex-col
|
||||
"
|
||||
>
|
||||
{/* Search Input */}
|
||||
<div className="pt-1 mx-1">
|
||||
<div className="relative">
|
||||
<SearchIcon
|
||||
size={16}
|
||||
className="absolute left-3 top-1/2 transform -translate-y-1/2 text-text-400"
|
||||
/>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search Menu"
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
className="
|
||||
w-full
|
||||
pl-9
|
||||
pr-3
|
||||
py-2
|
||||
bg-background-50
|
||||
rounded-lg
|
||||
text-sm
|
||||
outline-none
|
||||
text-text-700
|
||||
placeholder:text-text-400
|
||||
"
|
||||
autoFocus
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Options */}
|
||||
<div className="pt-2 flex-1 overflow-y-auto mx-1 pb-2">
|
||||
{filteredTools.length === 0 ? (
|
||||
<div className="text-center py-1 text-text-400">
|
||||
No matching actions found
|
||||
</div>
|
||||
) : (
|
||||
filteredTools.map((tool) => (
|
||||
<ToolItem
|
||||
key={tool.id}
|
||||
tool={tool}
|
||||
isToggled={!disabledToolIds.includes(tool.id)}
|
||||
isForced={forcedToolIds.includes(tool.id)}
|
||||
onToggle={() => toggleToolForCurrentAssistant(tool.id)}
|
||||
onForceToggle={() => {
|
||||
toggleForcedTool(tool.id);
|
||||
setOpen(false);
|
||||
}}
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="border-b border-border mx-3.5" />
|
||||
|
||||
{/* More Connectors & Actions. Only show if user is admin or curator, since
|
||||
they are the only ones who can manage actions. */}
|
||||
{(isAdmin || isCurator) && (
|
||||
<Link href="/admin/actions">
|
||||
<button
|
||||
className="
|
||||
w-full
|
||||
flex
|
||||
items-center
|
||||
justify-between
|
||||
text-text-400
|
||||
text-sm
|
||||
mt-2.5
|
||||
"
|
||||
>
|
||||
<div
|
||||
className="
|
||||
mx-2
|
||||
mb-2
|
||||
px-2
|
||||
py-1.5
|
||||
flex
|
||||
items-center
|
||||
hover:bg-background-100
|
||||
hover:text-text-500
|
||||
transition-colors
|
||||
rounded-lg
|
||||
w-full
|
||||
"
|
||||
>
|
||||
<MoreActionsIcon className="text-text-500" />
|
||||
<div className="ml-2">More Actions</div>
|
||||
</div>
|
||||
</button>
|
||||
</Link>
|
||||
)}
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
import React from "react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
|
||||
interface AgenticToggleProps {
|
||||
proSearchEnabled: boolean;
|
||||
setProSearchEnabled: (enabled: boolean) => void;
|
||||
}
|
||||
|
||||
const ProSearchIcon = () => (
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<path
|
||||
d="M21 21L16.65 16.65M19 11C19 15.4183 15.4183 19 11 19C6.58172 19 3 15.4183 3 11C3 6.58172 6.58172 3 11 3C15.4183 3 19 6.58172 19 11Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M11 8V14M8 11H14"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export function AgenticToggle({
|
||||
proSearchEnabled,
|
||||
setProSearchEnabled,
|
||||
}: AgenticToggleProps) {
|
||||
const handleToggle = () => {
|
||||
setProSearchEnabled(!proSearchEnabled);
|
||||
};
|
||||
|
||||
return (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<button
|
||||
className={`ml-auto py-1.5
|
||||
rounded-lg
|
||||
group
|
||||
px-2 inline-flex items-center`}
|
||||
onClick={handleToggle}
|
||||
role="switch"
|
||||
aria-checked={proSearchEnabled}
|
||||
>
|
||||
<div
|
||||
className={`
|
||||
${
|
||||
proSearchEnabled
|
||||
? "border-background-200 group-hover:border-[#000] dark:group-hover:border-neutral-300"
|
||||
: "border-background-200 group-hover:border-[#000] dark:group-hover:border-neutral-300"
|
||||
}
|
||||
relative inline-flex h-[16px] w-8 items-center rounded-full transition-colors focus:outline-none border animate transition-all duration-200 border-background-200 group-hover:border-[1px] `}
|
||||
>
|
||||
<span
|
||||
className={`${
|
||||
proSearchEnabled
|
||||
? "bg-agent translate-x-4 scale-75"
|
||||
: "bg-background-600 group-hover:bg-background-950 translate-x-0.5 scale-75"
|
||||
} inline-block h-[12px] w-[12px] group-hover:scale-90 transform rounded-full transition-transform duration-200 ease-in-out`}
|
||||
/>
|
||||
</div>
|
||||
<span
|
||||
className={`ml-2 text-sm font-[550] flex items-center ${
|
||||
proSearchEnabled ? "text-agent" : "text-text-dark"
|
||||
}`}
|
||||
>
|
||||
Agent
|
||||
</span>
|
||||
</button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent
|
||||
side="top"
|
||||
width="w-72"
|
||||
className="p-4 bg-white rounded-lg shadow-lg border border-background-200 dark:border-neutral-900"
|
||||
>
|
||||
<div className="flex items-center space-x-2 mb-3">
|
||||
<h3 className="text-sm font-semibold text-neutral-900">
|
||||
Agent Search
|
||||
</h3>
|
||||
</div>
|
||||
<p className="text-xs text-neutral-600 dark:text-neutral-700 mb-2">
|
||||
Use AI agents to break down questions and run deep iterative
|
||||
research through promising pathways. Gives more thorough and
|
||||
accurate responses but takes slightly longer.
|
||||
</p>
|
||||
<ul className="text-xs text-text-600 dark:text-neutral-700 list-disc list-inside">
|
||||
<li>Improved accuracy of search results</li>
|
||||
<li>Less hallucinations</li>
|
||||
<li>More comprehensive answers</li>
|
||||
</ul>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useContext, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { FiPlusCircle, FiPlus, FiFilter } from "react-icons/fi";
|
||||
import { FiPlus, FiFilter } from "react-icons/fi";
|
||||
import { FiLoader } from "react-icons/fi";
|
||||
import { ChatInputOption } from "./ChatInputOption";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
@@ -17,12 +17,6 @@ import {
|
||||
StopGeneratingIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { OnyxDocument, SourceMetadata } from "@/lib/search/interfaces";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { ChatState } from "@/app/chat/interfaces";
|
||||
import { useAssistantsContext } from "@/components/context/AssistantsContext";
|
||||
import { CalendarIcon, TagIcon, XIcon, FolderIcon } from "lucide-react";
|
||||
@@ -33,11 +27,12 @@ import { getFormattedDateRangeString } from "@/lib/dateUtils";
|
||||
import { truncateString } from "@/lib/utils";
|
||||
import { buildImgUrl } from "@/app/chat/components/files/images/utils";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { AgenticToggle } from "./AgenticToggle";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { useDocumentsContext } from "@/app/chat/my-documents/DocumentsContext";
|
||||
import { UnconfiguredLlmProviderText } from "@/components/chat/UnconfiguredLlmProviderText";
|
||||
import { DeepResearchToggle } from "./DeepResearchToggle";
|
||||
import { ActionToggle } from "./ActionManagement";
|
||||
import { SelectedTool } from "./SelectedTool";
|
||||
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
|
||||
@@ -105,16 +100,13 @@ interface ChatInputBarProps {
|
||||
selectedAssistant: MinimalPersonaSnapshot;
|
||||
|
||||
toggleDocumentSidebar: () => void;
|
||||
setFiles: (files: FileDescriptor[]) => void;
|
||||
handleFileUpload: (files: File[]) => void;
|
||||
textAreaRef: React.RefObject<HTMLTextAreaElement>;
|
||||
filterManager: FilterManager;
|
||||
availableSources: SourceMetadata[];
|
||||
availableDocumentSets: DocumentSetSummary[];
|
||||
availableTags: Tag[];
|
||||
retrievalEnabled: boolean;
|
||||
deepResearchEnabled: boolean;
|
||||
setDeepResearchEnabled: (deepResearchEnabled: boolean) => void;
|
||||
placeholder?: string;
|
||||
}
|
||||
|
||||
export function ChatInputBar({
|
||||
@@ -134,15 +126,12 @@ export function ChatInputBar({
|
||||
// assistants
|
||||
selectedAssistant,
|
||||
|
||||
setFiles,
|
||||
handleFileUpload,
|
||||
textAreaRef,
|
||||
availableSources,
|
||||
availableDocumentSets,
|
||||
availableTags,
|
||||
llmManager,
|
||||
deepResearchEnabled,
|
||||
setDeepResearchEnabled,
|
||||
placeholder,
|
||||
}: ChatInputBarProps) {
|
||||
const { user } = useUser();
|
||||
const {
|
||||
@@ -154,6 +143,8 @@ export function ChatInputBar({
|
||||
setCurrentMessageFiles,
|
||||
} = useDocumentsContext();
|
||||
|
||||
const { forcedToolIds, setForcedToolIds } = useAssistantsContext();
|
||||
|
||||
// Create a Set of IDs from currentMessageFiles for efficient lookup
|
||||
// Assuming FileDescriptor.id corresponds conceptually to FileResponse.file_id or FileResponse.id
|
||||
const currentMessageFileIds = useMemo(
|
||||
@@ -191,8 +182,6 @@ export function ChatInputBar({
|
||||
}
|
||||
};
|
||||
|
||||
const { finalAssistants: assistantOptions } = useAssistantsContext();
|
||||
|
||||
const { llmProviders, inputPrompts } = useChatContext();
|
||||
|
||||
const suggestionsRef = useRef<HTMLDivElement | null>(null);
|
||||
@@ -486,7 +475,10 @@ export function ChatInputBar({
|
||||
style={{ scrollbarWidth: "thin" }}
|
||||
role="textarea"
|
||||
aria-multiline
|
||||
placeholder={`How can ${selectedAssistant.name} help you today`}
|
||||
placeholder={
|
||||
placeholder ||
|
||||
`How can ${selectedAssistant.name} help you today`
|
||||
}
|
||||
value={message}
|
||||
onKeyDown={(event) => {
|
||||
if (
|
||||
@@ -659,7 +651,7 @@ export function ChatInputBar({
|
||||
)}
|
||||
|
||||
<div className="flex pr-4 pb-2 justify-between bg-input-background items-center w-full ">
|
||||
<div className="space-x-1 flex px-4 ">
|
||||
<div className="space-x-1 flex px-4 ">
|
||||
<ChatInputOption
|
||||
flexPriority="stiff"
|
||||
Icon={FileUploadIcon}
|
||||
@@ -669,23 +661,8 @@ export function ChatInputBar({
|
||||
tooltipContent={"Upload files and attach user files"}
|
||||
/>
|
||||
|
||||
{retrievalEnabled && (
|
||||
<FilterPopup
|
||||
availableSources={availableSources}
|
||||
availableDocumentSets={
|
||||
selectedAssistant.document_sets &&
|
||||
selectedAssistant.document_sets.length > 0
|
||||
? selectedAssistant.document_sets
|
||||
: availableDocumentSets
|
||||
}
|
||||
availableTags={availableTags}
|
||||
filterManager={filterManager}
|
||||
trigger={{
|
||||
name: "Filters",
|
||||
Icon: FiFilter,
|
||||
tooltipContent: "Filter your search",
|
||||
}}
|
||||
/>
|
||||
{selectedAssistant.tools.length > 0 && (
|
||||
<ActionToggle selectedAssistant={selectedAssistant} />
|
||||
)}
|
||||
|
||||
{retrievalEnabled &&
|
||||
@@ -695,7 +672,32 @@ export function ChatInputBar({
|
||||
setDeepResearchEnabled={setDeepResearchEnabled}
|
||||
/>
|
||||
)}
|
||||
|
||||
{forcedToolIds.length > 0 && (
|
||||
<div className="pl-1 flex items-center gap-2 text-blue-500">
|
||||
{forcedToolIds.map((toolId) => {
|
||||
const tool = selectedAssistant.tools.find(
|
||||
(tool) => tool.id === toolId
|
||||
);
|
||||
if (!tool) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<SelectedTool
|
||||
key={toolId}
|
||||
tool={tool}
|
||||
onClick={() => {
|
||||
setForcedToolIds((prev) =>
|
||||
prev.filter((id) => id !== toolId)
|
||||
);
|
||||
}}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center my-auto gap-x-2">
|
||||
<LLMPopover
|
||||
llmProviders={llmProviders}
|
||||
|
||||
25
web/src/app/chat/components/input/SelectedTool.tsx
Normal file
25
web/src/app/chat/components/input/SelectedTool.tsx
Normal file
@@ -0,0 +1,25 @@
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { getIconForAction } from "../../services/actionUtils";
|
||||
import { XIcon } from "@/components/icons/icons";
|
||||
|
||||
export function SelectedTool({
|
||||
tool,
|
||||
onClick,
|
||||
}: {
|
||||
tool: ToolSnapshot;
|
||||
onClick: () => void;
|
||||
}) {
|
||||
const Icon = getIconForAction(tool);
|
||||
return (
|
||||
<div
|
||||
className="flex items-center cursor-pointer hover:bg-background-100 rounded-lg p-1"
|
||||
onClick={onClick}
|
||||
>
|
||||
<Icon size={16} />
|
||||
<span className="text-sm font-medium select-none ml-1.5 mr-1">
|
||||
{tool.display_name}
|
||||
</span>
|
||||
<XIcon size={12} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
import { StarterMessage } from "@/app/admin/assistants/interfaces";
|
||||
|
||||
export function StarterMessageDisplay({
|
||||
starterMessages,
|
||||
onSelectStarterMessage,
|
||||
}: {
|
||||
starterMessages: StarterMessage[];
|
||||
onSelectStarterMessage: (message: string) => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex flex-col gap-2 w-full max-w-searchbar-max mx-auto">
|
||||
{starterMessages.map((starterMessage) => (
|
||||
<div
|
||||
key={starterMessage.name}
|
||||
onClick={() => onSelectStarterMessage(starterMessage.message)}
|
||||
className="
|
||||
text-left
|
||||
text-text-500
|
||||
text-sm
|
||||
mx-7
|
||||
px-2
|
||||
py-2
|
||||
hover:bg-background-100
|
||||
rounded-lg
|
||||
cursor-pointer
|
||||
overflow-hidden
|
||||
text-ellipsis
|
||||
whitespace-nowrap
|
||||
"
|
||||
>
|
||||
{starterMessage.name}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
57
web/src/app/chat/hooks/useAssistantPreferences.ts
Normal file
57
web/src/app/chat/hooks/useAssistantPreferences.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import {
|
||||
UserSpecificAssistantPreference,
|
||||
UserSpecificAssistantPreferences,
|
||||
} from "@/lib/types";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
const ASSISTANT_PREFERENCES_URL = "/api/user/assistant/preferences";
|
||||
|
||||
const buildUpdateAssistantPreferenceUrl = (assistantId: number) =>
|
||||
`/api/user/assistant/${assistantId}/preferences`;
|
||||
|
||||
export function useAssistantPreferences() {
|
||||
const [assistantPreferences, _setAssistantPreferences] =
|
||||
useState<UserSpecificAssistantPreferences | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchAssistantPreferences = async () => {
|
||||
const response = await fetch(ASSISTANT_PREFERENCES_URL);
|
||||
const data = await response.json();
|
||||
_setAssistantPreferences(data);
|
||||
};
|
||||
fetchAssistantPreferences();
|
||||
}, []);
|
||||
|
||||
const setSpecificAssistantPreferences = async (
|
||||
assistantId: number,
|
||||
newAssistantPreference: UserSpecificAssistantPreference
|
||||
) => {
|
||||
_setAssistantPreferences({
|
||||
...assistantPreferences,
|
||||
[assistantId]: newAssistantPreference,
|
||||
});
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
buildUpdateAssistantPreferenceUrl(assistantId),
|
||||
{
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(newAssistantPreference),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
console.error(
|
||||
`Failed to update assistant preferences: ${response.status}`
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating assistant preferences:", error);
|
||||
}
|
||||
};
|
||||
|
||||
return { assistantPreferences, setSpecificAssistantPreferences };
|
||||
}
|
||||
@@ -22,7 +22,6 @@ import { FilterManager, LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import {
|
||||
BackendMessage,
|
||||
ChatFileType,
|
||||
ChatSessionSharedStatus,
|
||||
CitationMap,
|
||||
FileChatDisplay,
|
||||
FileDescriptor,
|
||||
@@ -31,7 +30,6 @@ import {
|
||||
RegenerationState,
|
||||
RetrievalType,
|
||||
StreamingError,
|
||||
SubQuestionDetail,
|
||||
ToolCallMetadata,
|
||||
UserKnowledgeFilePacket,
|
||||
} from "../interfaces";
|
||||
@@ -73,6 +71,7 @@ import {
|
||||
MessageStart,
|
||||
PacketType,
|
||||
} from "../services/streamingModels";
|
||||
import { useAssistantsContext } from "@/components/context/AssistantsContext";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@@ -122,6 +121,7 @@ export function useChatController({
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const { refreshChatSessions, llmProviders } = useChatContext();
|
||||
const { assistantPreferences, forcedToolIds } = useAssistantsContext();
|
||||
|
||||
// Use selectors to access only the specific fields we need
|
||||
const currentSessionId = useChatSessionStore(
|
||||
@@ -157,12 +157,6 @@ export function useChatController({
|
||||
const setAbortController = useChatSessionStore(
|
||||
(state) => state.setAbortController
|
||||
);
|
||||
const setAgenticGenerating = useChatSessionStore(
|
||||
(state) => state.setAgenticGenerating
|
||||
);
|
||||
const setIsFetchingChatMessages = useChatSessionStore(
|
||||
(state) => state.setIsFetchingChatMessages
|
||||
);
|
||||
const setIsReady = useChatSessionStore((state) => state.setIsReady);
|
||||
|
||||
// Use custom hooks for accessing store data
|
||||
@@ -170,21 +164,13 @@ export function useChatController({
|
||||
const currentMessageHistory = useCurrentMessageHistory();
|
||||
const currentChatState = useCurrentChatState();
|
||||
|
||||
const {
|
||||
selectedFiles,
|
||||
selectedFolders,
|
||||
addSelectedFile,
|
||||
uploadFile,
|
||||
setCurrentMessageFiles,
|
||||
clearSelectedItems,
|
||||
} = useDocumentsContext();
|
||||
const { selectedFiles, selectedFolders, uploadFile, setCurrentMessageFiles } =
|
||||
useDocumentsContext();
|
||||
|
||||
const navigatingAway = useRef(false);
|
||||
|
||||
// Local state that doesn't need to be in the store
|
||||
const [maxTokens, setMaxTokens] = useState<number>(4096);
|
||||
const [chatSessionSharedStatus, setChatSessionSharedStatus] =
|
||||
useState<ChatSessionSharedStatus>(ChatSessionSharedStatus.Private);
|
||||
|
||||
// Sync store state changes
|
||||
useEffect(() => {
|
||||
@@ -536,6 +522,9 @@ export function useChatController({
|
||||
const lastSuccessfulMessageId = getLastSuccessfulMessageId(
|
||||
currentMessageTreeLocal
|
||||
);
|
||||
const disabledToolIds = liveAssistant
|
||||
? assistantPreferences?.[liveAssistant?.id]?.disabled_tool_ids
|
||||
: undefined;
|
||||
|
||||
const stack = new CurrentMessageFIFO();
|
||||
updateCurrentMessageFIFO(stack, {
|
||||
@@ -581,6 +570,13 @@ export function useChatController({
|
||||
searchParams?.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
|
||||
useExistingUserMessage: isSeededChat,
|
||||
useAgentSearch,
|
||||
enabledToolIds:
|
||||
disabledToolIds && liveAssistant
|
||||
? liveAssistant.tools
|
||||
.filter((tool) => !disabledToolIds?.includes(tool.id))
|
||||
.map((tool) => tool.id)
|
||||
: undefined,
|
||||
forcedToolIds: forcedToolIds,
|
||||
});
|
||||
|
||||
const delay = (ms: number) => {
|
||||
@@ -639,7 +635,6 @@ export function useChatController({
|
||||
) {
|
||||
setUncaughtError(frozenSessionId, (packet as StreamingError).error);
|
||||
updateChatStateAction(frozenSessionId, "input");
|
||||
setAgenticGenerating(frozenSessionId, false);
|
||||
updateSubmittedMessage(getCurrentSessionId(), "");
|
||||
|
||||
throw new Error((packet as StreamingError).error);
|
||||
@@ -767,7 +762,6 @@ export function useChatController({
|
||||
currentMessageTreeLocal = newMessageDetails.messageTree;
|
||||
}
|
||||
|
||||
setAgenticGenerating(frozenSessionId, false);
|
||||
resetRegenerationState(frozenSessionId);
|
||||
|
||||
updateChatStateAction(frozenSessionId, "input");
|
||||
|
||||
@@ -25,6 +25,7 @@ import {
|
||||
useCurrentMessageHistory,
|
||||
} from "../stores/useChatSessionStore";
|
||||
import { getCitations } from "../services/packetUtils";
|
||||
import { useAssistantsContext } from "@/components/context/AssistantsContext";
|
||||
|
||||
interface UseChatSessionControllerProps {
|
||||
existingChatSessionId: string | null;
|
||||
@@ -111,6 +112,7 @@ export function useChatSessionController({
|
||||
state.sessions.get(state.currentSessionId || "")?.chatState || "input"
|
||||
);
|
||||
const currentChatHistory = useCurrentMessageHistory();
|
||||
const { setForcedToolIds } = useAssistantsContext();
|
||||
|
||||
// Fetch chat messages for the chat session
|
||||
useEffect(() => {
|
||||
@@ -142,6 +144,9 @@ export function useChatSessionController({
|
||||
if (existingChatSessionId) {
|
||||
updateHasPerformedInitialScroll(existingChatSessionId, false);
|
||||
}
|
||||
|
||||
// Clear forced tool ids if and only if we're switching to a new chat session
|
||||
setForcedToolIds([]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
64
web/src/app/chat/services/actionUtils.ts
Normal file
64
web/src/app/chat/services/actionUtils.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
import {
|
||||
CpuIcon,
|
||||
DatabaseIcon,
|
||||
IconProps,
|
||||
UsersIcon,
|
||||
AppSearchIcon,
|
||||
GlobeIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
|
||||
// Helper functions to identify specific tools
|
||||
const isSearchTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "SearchTool" ||
|
||||
tool.name === "run_search" ||
|
||||
tool.display_name?.toLowerCase().includes("search tool")
|
||||
);
|
||||
};
|
||||
|
||||
const isWebSearchTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "InternetSearchTool" ||
|
||||
tool.display_name?.toLowerCase().includes("internet search")
|
||||
);
|
||||
};
|
||||
|
||||
const isImageGenerationTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "ImageGenerationTool" ||
|
||||
tool.display_name?.toLowerCase().includes("image generation")
|
||||
);
|
||||
};
|
||||
|
||||
const isKnowledgeGraphTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "KnowledgeGraphTool" ||
|
||||
tool.display_name?.toLowerCase().includes("knowledge graph")
|
||||
);
|
||||
};
|
||||
|
||||
const isOktaProfileTool = (tool: ToolSnapshot): boolean => {
|
||||
return (
|
||||
tool.in_code_tool_id === "OktaProfileTool" ||
|
||||
tool.display_name?.toLowerCase().includes("okta profile")
|
||||
);
|
||||
};
|
||||
|
||||
export function getIconForAction(
|
||||
action: ToolSnapshot
|
||||
): (iconProps: IconProps) => JSX.Element {
|
||||
if (isSearchTool(action)) {
|
||||
return AppSearchIcon;
|
||||
} else if (isWebSearchTool(action)) {
|
||||
return GlobeIcon;
|
||||
} else if (isImageGenerationTool(action)) {
|
||||
return DatabaseIcon;
|
||||
} else if (isKnowledgeGraphTool(action)) {
|
||||
return DatabaseIcon;
|
||||
} else if (isOktaProfileTool(action)) {
|
||||
return UsersIcon;
|
||||
} else {
|
||||
return CpuIcon;
|
||||
}
|
||||
}
|
||||
@@ -176,6 +176,8 @@ export interface SendMessageParams {
|
||||
userFileIds?: number[];
|
||||
userFolderIds?: number[];
|
||||
useAgentSearch?: boolean;
|
||||
enabledToolIds?: number[];
|
||||
forcedToolIds?: number[];
|
||||
}
|
||||
|
||||
export async function* sendMessage({
|
||||
@@ -198,6 +200,8 @@ export async function* sendMessage({
|
||||
alternateAssistantId,
|
||||
signal,
|
||||
useAgentSearch,
|
||||
enabledToolIds,
|
||||
forcedToolIds,
|
||||
}: SendMessageParams): AsyncGenerator<PacketType, void, unknown> {
|
||||
const documentsAreSelected =
|
||||
selectedDocumentIds && selectedDocumentIds.length > 0;
|
||||
@@ -238,6 +242,8 @@ export async function* sendMessage({
|
||||
: null,
|
||||
use_existing_user_message: useExistingUserMessage,
|
||||
use_agentic_search: useAgentSearch ?? false,
|
||||
allowed_tool_ids: enabledToolIds,
|
||||
forced_tool_ids: forcedToolIds,
|
||||
});
|
||||
|
||||
const response = await fetch(`/api/chat/send-message`, {
|
||||
|
||||
@@ -29,7 +29,6 @@ interface ChatSessionData {
|
||||
|
||||
// Session-specific state (previously global)
|
||||
isFetchingChatMessages: boolean;
|
||||
agenticGenerating: boolean;
|
||||
uncaughtError: string | null;
|
||||
loadingError: string | null;
|
||||
isReady: boolean;
|
||||
@@ -112,7 +111,6 @@ interface ChatSessionStore {
|
||||
|
||||
// Actions - Session-specific State (previously global)
|
||||
setIsFetchingChatMessages: (sessionId: string, fetching: boolean) => void;
|
||||
setAgenticGenerating: (sessionId: string, generating: boolean) => void;
|
||||
setUncaughtError: (sessionId: string, error: string | null) => void;
|
||||
setLoadingError: (sessionId: string, error: string | null) => void;
|
||||
setIsReady: (sessionId: string, ready: boolean) => void;
|
||||
@@ -150,7 +148,6 @@ const createInitialSessionData = (
|
||||
|
||||
// Session-specific state defaults
|
||||
isFetchingChatMessages: false,
|
||||
agenticGenerating: false,
|
||||
uncaughtError: null,
|
||||
loadingError: null,
|
||||
isReady: true,
|
||||
@@ -387,10 +384,6 @@ export const useChatSessionStore = create<ChatSessionStore>()((set, get) => ({
|
||||
get().updateSessionData(sessionId, { isFetchingChatMessages });
|
||||
},
|
||||
|
||||
setAgenticGenerating: (sessionId: string, agenticGenerating: boolean) => {
|
||||
get().updateSessionData(sessionId, { agenticGenerating });
|
||||
},
|
||||
|
||||
setUncaughtError: (sessionId: string, uncaughtError: string | null) => {
|
||||
get().updateSessionData(sessionId, { uncaughtError });
|
||||
},
|
||||
@@ -569,15 +562,6 @@ export const useAbortControllers = () => {
|
||||
};
|
||||
|
||||
// Session-specific state hooks (previously global)
|
||||
export const useAgenticGenerating = () =>
|
||||
useChatSessionStore((state) => {
|
||||
const { currentSessionId, sessions } = state;
|
||||
const currentSession = currentSessionId
|
||||
? sessions.get(currentSessionId)
|
||||
: null;
|
||||
return currentSession?.agenticGenerating || false;
|
||||
});
|
||||
|
||||
export const useIsFetching = () =>
|
||||
useChatSessionStore((state) => {
|
||||
const { currentSessionId, sessions } = state;
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import { useContext } from "react";
|
||||
import { MinimalPersonaSnapshot } from "../../app/admin/assistants/interfaces";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
|
||||
export function StarterMessages({
|
||||
currentPersona,
|
||||
onSubmit,
|
||||
}: {
|
||||
currentPersona: MinimalPersonaSnapshot;
|
||||
onSubmit: (messageOverride: string) => void;
|
||||
}) {
|
||||
const settings = useContext(SettingsContext);
|
||||
const isMobile = settings?.isMobile;
|
||||
return (
|
||||
<div
|
||||
key={-4}
|
||||
className={`
|
||||
very-short:hidden
|
||||
mx-auto
|
||||
w-full
|
||||
${
|
||||
isMobile
|
||||
? "gap-x-2 w-2/3 justify-between"
|
||||
: "justify-center max-w-[750px] items-start"
|
||||
}
|
||||
flex
|
||||
mt-6
|
||||
`}
|
||||
>
|
||||
{currentPersona?.starter_messages &&
|
||||
currentPersona.starter_messages.length > 0 && (
|
||||
<>
|
||||
{currentPersona.starter_messages
|
||||
.slice(0, isMobile ? 2 : 4)
|
||||
.map((starterMessage, i) => (
|
||||
<div
|
||||
key={i}
|
||||
className={`${
|
||||
isMobile ? "w-1/2" : "w-1/4"
|
||||
} flex justify-center`}
|
||||
>
|
||||
<button
|
||||
onClick={() => onSubmit(starterMessage.message)}
|
||||
className={`
|
||||
relative flex ${!isMobile ? "w-40" : "w-full max-w-52"}
|
||||
shadow
|
||||
border-background-300/60
|
||||
flex-col gap-2 rounded-md
|
||||
text-input-text hover:text-text
|
||||
border
|
||||
dark:bg-transparent
|
||||
dark:border-neutral-700
|
||||
dark:hover:bg-background-150
|
||||
font-normal
|
||||
px-3 py-2
|
||||
text-start align-to text-wrap
|
||||
text-[15px] shadow-xs transition
|
||||
enabled:hover:bg-background-dark/75
|
||||
disabled:cursor-not-allowed
|
||||
overflow-hidden
|
||||
break-words
|
||||
truncate
|
||||
text-ellipsis
|
||||
`}
|
||||
style={{ height: "5.6rem" }}
|
||||
>
|
||||
<div className="overflow-hidden text-ellipsis line-clamp-3 pr-1 pb-1">
|
||||
{starterMessage.name}
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -16,6 +16,11 @@ import {
|
||||
filterAssistants,
|
||||
} from "@/lib/assistants/utils";
|
||||
import { useUser } from "../user/UserProvider";
|
||||
import {
|
||||
UserSpecificAssistantPreference,
|
||||
UserSpecificAssistantPreferences,
|
||||
} from "@/lib/types";
|
||||
import { useAssistantPreferences } from "@/app/chat/hooks/useAssistantPreferences";
|
||||
|
||||
interface AssistantsContextProps {
|
||||
assistants: MinimalPersonaSnapshot[];
|
||||
@@ -25,8 +30,18 @@ interface AssistantsContextProps {
|
||||
ownedButHiddenAssistants: MinimalPersonaSnapshot[];
|
||||
refreshAssistants: () => Promise<void>;
|
||||
isImageGenerationAvailable: boolean;
|
||||
|
||||
pinnedAssistants: MinimalPersonaSnapshot[];
|
||||
setPinnedAssistants: Dispatch<SetStateAction<MinimalPersonaSnapshot[]>>;
|
||||
|
||||
assistantPreferences: UserSpecificAssistantPreferences | null;
|
||||
setSpecificAssistantPreferences: (
|
||||
assistantId: number,
|
||||
assistantPreferences: UserSpecificAssistantPreference
|
||||
) => void;
|
||||
|
||||
forcedToolIds: number[];
|
||||
setForcedToolIds: Dispatch<SetStateAction<number[]>>;
|
||||
}
|
||||
|
||||
const AssistantsContext = createContext<AssistantsContextProps | undefined>(
|
||||
@@ -43,6 +58,9 @@ export const AssistantsProvider: React.FC<{
|
||||
initialAssistants || []
|
||||
);
|
||||
const { user } = useUser();
|
||||
const { assistantPreferences, setSpecificAssistantPreferences } =
|
||||
useAssistantPreferences();
|
||||
const [forcedToolIds, setForcedToolIds] = useState<number[]>([]);
|
||||
|
||||
const [pinnedAssistants, setPinnedAssistants] = useState<
|
||||
MinimalPersonaSnapshot[]
|
||||
@@ -149,6 +167,10 @@ export const AssistantsProvider: React.FC<{
|
||||
isImageGenerationAvailable,
|
||||
setPinnedAssistants,
|
||||
pinnedAssistants,
|
||||
assistantPreferences,
|
||||
setSpecificAssistantPreferences,
|
||||
forcedToolIds,
|
||||
setForcedToolIds,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
|
||||
@@ -394,6 +394,121 @@ export const PlusCircleIcon = ({
|
||||
);
|
||||
};
|
||||
|
||||
export const AppSearchIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox={`0 0 ${size} ${size}`}
|
||||
fill="none"
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
>
|
||||
<path
|
||||
d="M1.00261 7.5H2.5M1 4H3.25M1.00261 11H3.25M15 13L12.682 10.682M12.682 10.682C13.4963 9.86764 14 8.74264 14 7.5C14 5.01472 11.9853 3 9.49999 3C7.01472 3 5 5.01472 5 7.5C5 9.98528 7.01472 12 9.49999 12C10.7426 12 11.8676 11.4963 12.682 10.682Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const DisableIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox={`0 0 ${size} ${size}`}
|
||||
fill="none"
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
>
|
||||
<g clipPath="url(#clip0_295_7943)">
|
||||
<path
|
||||
d="M3.28659 3.28665L12.7133 12.7133M14.6666 7.99998C14.6666 11.6819 11.6818 14.6666 7.99992 14.6666C4.31802 14.6666 1.33325 11.6819 1.33325 7.99998C1.33325 4.31808 4.31802 1.33331 7.99992 1.33331C11.6818 1.33331 14.6666 4.31808 14.6666 7.99998Z"
|
||||
stroke="currentColor"
|
||||
strokeOpacity="0.4"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_295_7943">
|
||||
<rect width="16" height="16" fill="white" />
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const MoreActionsIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox={`0 0 ${size} ${size}`}
|
||||
fill="none"
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
>
|
||||
<path
|
||||
d="M3.06 6.24449L5.12 4.12225L3.06 2.00001M11.5501 14L14 11.5501M14 11.5501L11.5501 9.10017M14 11.5501H9.75552M4.12224 9.09889L6.24448 10.3242V12.7747L4.12224 14L2 12.7747V10.3242L4.12224 9.09889ZM14 4.12225C14 5.29433 13.0498 6.24449 11.8778 6.24449C10.7057 6.24449 9.75552 5.29433 9.75552 4.12225C9.75552 2.95017 10.7057 2.00001 11.8778 2.00001C13.0498 2.00001 14 2.95017 14 4.12225Z"
|
||||
stroke="currentColor"
|
||||
strokeOpacity="0.4"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const SlidersVerticalIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
viewBox={`0 0 ${size} ${size}`}
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<g clipPath="url(#clip0_16_2627)">
|
||||
<path
|
||||
d="M2.66666 14V9.33333M2.66666 6.66667V2M7.99999 14V8M7.99999 5.33333V2M13.3333 14V10.6667M13.3333 8V2M0.666656 9.33333H4.66666M5.99999 5.33333H9.99999M11.3333 10.6667H15.3333"
|
||||
stroke="currentColor"
|
||||
strokeOpacity="0.8"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_16_2627">
|
||||
<rect width="16" height="16" fill="white" />
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const PlugIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
|
||||
@@ -3,6 +3,15 @@ import { Credential } from "./connectors/credentials";
|
||||
import { Connector } from "./connectors/connectors";
|
||||
import { ConnectorCredentialPairStatus } from "@/app/admin/connector/[ccPairId]/types";
|
||||
|
||||
export interface UserSpecificAssistantPreference {
|
||||
disabled_tool_ids?: number[];
|
||||
}
|
||||
|
||||
export type UserSpecificAssistantPreferences = Record<
|
||||
number,
|
||||
UserSpecificAssistantPreference
|
||||
>;
|
||||
|
||||
interface UserPreferences {
|
||||
chosen_assistants: number[] | null;
|
||||
visible_assistants: number[];
|
||||
|
||||
Reference in New Issue
Block a user