Compare commits

..

15 Commits

Author SHA1 Message Date
pablonyx
3a01014212 k 2025-02-27 18:04:19 -08:00
pablonyx
45b6c5bfed address comments 2025-02-27 15:37:53 -08:00
pablonyx
04e980a0e8 fix build 2025-02-27 15:25:58 -08:00
pablonyx
3c2480ef21 k 2025-02-27 15:23:51 -08:00
pablonyx
ebb57d6216 k 2025-02-27 15:23:51 -08:00
Richard Kuo (Danswer)
4c230f92ea trivy test 2025-02-27 15:05:03 -08:00
Richard Kuo (Danswer)
07d75b04d1 enable trivy scan 2025-02-27 14:22:44 -08:00
evan-danswer
a8d10750c1 fix propagation of is_agentic (#4150) 2025-02-27 11:56:51 -08:00
pablonyx
85e3ed57f1 Order chat sessions by time updated, not created (#4143)
* order chat sessions by time updated, not created

* quick update

* k
2025-02-27 17:35:42 +00:00
pablonyx
e10cc8ccdb Multi tenant user google auth fix (#4145) 2025-02-27 10:35:38 -08:00
pablonyx
7018bc974b Better looking errors (#4050)
* add error handling

* fix

* k
2025-02-27 04:58:25 +00:00
pablonyx
9c9075d71d Minor improvements to provisioning (#4109)
* quick fix

* k

* nit
2025-02-27 04:57:31 +00:00
pablonyx
338e084062 Improved tenant handling for slack bot (#4099) 2025-02-27 04:06:26 +00:00
pablonyx
2f64031f5c Improved tenant handling for slack bot1 (#4104) 2025-02-27 03:40:50 +00:00
pablonyx
abb74f2eaa Improved chat search (#4137)
* functional + fast

* k

* adapt

* k

* nit

* k

* k

* fix typing

* k
2025-02-27 02:27:45 +00:00
36 changed files with 855 additions and 433 deletions

View File

@@ -53,22 +53,26 @@ jobs:
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
if: always()
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
if: always()
uses: aquasecurity/trivy-action@0.29.0
with:
scan-type: fs
scan-ref: .
scanners: license
format: table
severity: HIGH,CRITICAL
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3

View File

@@ -0,0 +1,84 @@
"""improved index
Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None
# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)
def downgrade() -> None:
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(tenant_id: str) -> BillingInformation:
def fetch_billing_information(
tenant_id: str,
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = BillingInformation(**response.json())
return billing_info
response_data = response.json()
# Check if the response indicates no subscription
if (
isinstance(response_data, dict)
and "subscribed" in response_data
and not response_data["subscribed"]
):
return SubscriptionStatusResponse(**response_data)
# Otherwise, parse as BillingInformation
return BillingInformation(**response_data)
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:

View File

@@ -200,25 +200,6 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
def configure_default_api_keys(db_session: Session) -> None:
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
@@ -227,6 +208,7 @@ def configure_default_api_keys(db_session: Session) -> None:
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
@@ -238,6 +220,26 @@ def configure_default_api_keys(db_session: Session) -> None:
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,

View File

@@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}
user: User
user: User | None = None
try:
# Attempt to get user by OAuth account
@@ -420,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = cast(User, await self.user_db.get_by_email(account_email))
user = await self.user_db.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# Make sure user is not None before adding OAuth account
if user is not None:
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
else:
# This shouldn't happen since get_by_email would raise UserNotExists
# but adding as a safeguard
raise exceptions.UserNotExists()
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
@@ -439,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = await self.user_db.create(user_dict)
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
# Add OAuth account only if user creation was successful
if user is not None:
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
raise HTTPException(
status_code=500, detail="Failed to create user account"
)
else:
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# User exists, update OAuth account if needed
if user is not None: # Add explicit check
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# Ensure user is not None before proceeding
if user is None:
raise HTTPException(
status_code=500, detail="Failed to authenticate or create user"
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled

View File

@@ -16,7 +16,7 @@ from typing import Optional
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.setup import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -962,6 +962,7 @@ def translate_db_message_to_chat_message_detail(
chat_message.sub_questions
),
refined_answer_improvement=chat_message.refined_answer_improvement,
is_agentic=chat_message.is_agentic,
error=chat_message.error,
)

View File

@@ -3,14 +3,13 @@ from typing import Optional
from typing import Tuple
from uuid import UUID
from sqlalchemy import column
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import literal
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import union_all
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -26,127 +25,87 @@ def search_chat_sessions(
include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
"""
Search for chat sessions based on the provided query.
If no query is provided, returns recent chat sessions.
Fast full-text search on ChatSession + ChatMessage using tsvectors.
Returns a tuple of (chat_sessions, has_more)
If no query is provided, returns the most recent chat sessions.
Otherwise, searches both chat messages and session descriptions.
Returns a tuple of (sessions, has_more) where has_more indicates if
there are additional results beyond the requested page.
"""
offset = (page - 1) * page_size
offset_val = (page - 1) * page_size
# If no search query, we use standard SQLAlchemy pagination
# If no query, just return the most recent sessions
if not query or not query.strip():
stmt = select(ChatSession)
if user_id:
stmt = (
select(ChatSession)
.order_by(desc(ChatSession.time_created))
.offset(offset_val)
.limit(page_size + 1)
)
if user_id is not None:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
# Apply pagination
stmt = stmt.offset(offset).limit(page_size + 1)
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
chat_sessions = result.scalars().all()
sessions = result.scalars().all()
has_more = len(chat_sessions) > page_size
has_more = len(sessions) > page_size
if has_more:
chat_sessions = chat_sessions[:page_size]
sessions = sessions[:page_size]
return list(chat_sessions), has_more
return list(sessions), has_more
words = query.lower().strip().split()
# Otherwise, proceed with full-text search
query = query.strip()
# Message mach subquery
message_matches = []
for word in words:
word_like = f"%{word}%"
message_match: Select = (
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
.where(func.lower(ChatMessage.message).like(word_like))
)
if user_id:
message_match = message_match.where(ChatSession.user_id == user_id)
message_matches.append(message_match)
if message_matches:
message_matches_query = union_all(*message_matches).alias("message_matches")
else:
return [], False
# Description matches
description_match: Select = select(
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
if user_id:
description_match = description_match.where(ChatSession.user_id == user_id)
base_conditions = []
if user_id is not None:
base_conditions.append(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
description_match = description_match.where(ChatSession.deleted.is_(False))
base_conditions.append(ChatSession.deleted.is_(False))
# Combine all match sources
combined_matches = union_all(
message_matches_query.select(), description_match
).alias("combined_matches")
message_tsv: ColumnClause = column("message_tsv")
description_tsv: ColumnClause = column("description_tsv")
# Use CTE to group and get max rank
session_ranks = (
select(
combined_matches.c.chat_session_id,
func.max(combined_matches.c.search_rank).label("rank"),
)
.group_by(combined_matches.c.chat_session_id)
.alias("session_ranks")
ts_query = func.plainto_tsquery("english", query)
description_session_ids = (
select(ChatSession.id)
.where(*base_conditions)
.where(description_tsv.op("@@")(ts_query))
)
# Get ranked sessions with pagination
ranked_query = (
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
.offset(offset)
message_session_ids = (
select(ChatMessage.chat_session_id)
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
.where(*base_conditions)
.where(message_tsv.op("@@")(ts_query))
)
combined_ids = description_session_ids.union(message_session_ids).alias(
"combined_ids"
)
final_stmt = (
select(ChatSession)
.join(combined_ids, ChatSession.id == combined_ids.c.id)
.order_by(desc(ChatSession.time_created))
.distinct()
.offset(offset_val)
.limit(page_size + 1)
.options(joinedload(ChatSession.persona))
)
result = ranked_query.all()
session_objs = db_session.execute(final_stmt).scalars().all()
# Extract session IDs and ranks
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
session_ids = list(session_ids_with_ranks.keys())
if not session_ids:
return [], False
# Now, let's query the actual ChatSession objects using the IDs
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
if user_id:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
# Full objects with eager loading
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
chat_sessions = result.scalars().all()
# Sort based on above ranking
chat_sessions = sorted(
chat_sessions,
key=lambda session: (
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
session.time_created.timestamp() * -1, # Then by time (newest first)
),
)
has_more = len(chat_sessions) > page_size
has_more = len(session_objs) > page_size
if has_more:
chat_sessions = chat_sessions[:page_size]
session_objs = session_objs[:page_size]
return chat_sessions, has_more
return list(session_objs), has_more

View File

@@ -25,6 +25,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Text

View File

@@ -23,7 +23,7 @@ from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_session_by_message_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChannelConfig
from onyx.onyxbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
@@ -410,12 +410,11 @@ def _build_qa_response_blocks(
def _build_continue_in_web_ui_block(
tenant_id: str,
message_id: int | None,
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,
@@ -482,7 +481,6 @@ def build_follow_up_resolved_blocks(
def build_slack_response_blocks(
answer: ChatOnyxBotResponse,
tenant_id: str,
message_info: SlackMessageInfo,
channel_conf: ChannelConfig | None,
use_citations: bool,
@@ -517,7 +515,6 @@ def build_slack_response_blocks(
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
web_follow_up_block.append(
_build_continue_in_web_ui_block(
tenant_id=tenant_id,
message_id=answer.chat_message_id,
)
)

View File

@@ -11,7 +11,7 @@ from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
@@ -114,7 +114,7 @@ def handle_generate_answer_button(
thread_ts=thread_ts,
)
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
@@ -136,7 +136,6 @@ def handle_generate_answer_button(
slack_channel_config=slack_channel_config,
receiver_ids=None,
client=client.web_client,
tenant_id=client.tenant_id,
channel=channel_id,
logger=logger,
feedback_reminder_id=None,
@@ -151,11 +150,10 @@ def handle_slack_feedback(
user_id_to_post_confirmation: str,
channel_id_to_post_confirmation: str,
thread_ts_to_post_confirmation: str,
tenant_id: str,
) -> None:
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
@@ -246,7 +244,7 @@ def handle_followup_button(
tag_ids: list[str] = []
group_ids: list[str] = []
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
channel_name, is_dm = get_channel_name_from_id(
client=client.web_client, channel_id=channel_id
)

View File

@@ -5,7 +5,7 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.onyxbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.users import add_slack_user_if_not_exists
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
@@ -109,7 +109,6 @@ def handle_message(
slack_channel_config: SlackChannelConfig,
client: WebClient,
feedback_reminder_id: str | None,
tenant_id: str,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -135,9 +134,7 @@ def handle_message(
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
)
slack_usage_report(action=action, sender_id=sender_id, client=client)
document_set_names: list[str] | None = None
persona = slack_channel_config.persona if slack_channel_config else None
@@ -218,7 +215,7 @@ def handle_message(
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
if message_info.email:
add_slack_user_if_not_exists(db_session, message_info.email)
@@ -244,6 +241,5 @@ def handle_message(
channel=channel,
logger=logger,
feedback_reminder_id=feedback_reminder_id,
tenant_id=tenant_id,
)
return issue_with_regular_answer

View File

@@ -24,7 +24,6 @@ from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
@@ -72,7 +71,6 @@ def handle_regular_answer(
channel: str,
logger: OnyxLoggingAdapter,
feedback_reminder_id: str | None,
tenant_id: str,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
@@ -87,7 +85,7 @@ def handle_regular_answer(
user = None
if message_info.is_bot_dm:
if message_info.email:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
@@ -96,7 +94,7 @@ def handle_regular_answer(
# This way slack flow always has a persona
persona = slack_channel_config.persona
if not persona:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
document_set_names = [
document_set.name for document_set in persona.document_sets
@@ -157,7 +155,7 @@ def handle_regular_answer(
def _get_slack_answer(
new_message_request: CreateChatMessageRequest, onyx_user: User | None
) -> ChatOnyxBotResponse:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
packets = stream_chat_message_objects(
new_msg_req=new_message_request,
user=onyx_user,
@@ -197,7 +195,7 @@ def handle_regular_answer(
enable_auto_detect_filters=auto_detect_filters,
)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
answer_request = prepare_chat_message_request(
message_text=user_message.message,
user=user,
@@ -361,7 +359,6 @@ def handle_regular_answer(
return True
all_blocks = build_slack_response_blocks(
tenant_id=tenant_id,
message_info=message_info,
answer=answer,
channel_conf=channel_conf,

View File

@@ -37,6 +37,7 @@ from onyx.context.search.retrieval.search_runner import (
download_nltk_data,
)
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import SlackBot
from onyx.db.search_settings import get_current_search_settings
@@ -92,6 +93,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -347,7 +349,7 @@ class SlackbotHandler:
redis_client = get_redis_client(tenant_id=tenant_id)
try:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
@@ -586,7 +588,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
channel_name, _ = get_channel_name_from_id(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
@@ -680,7 +682,6 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
user_id_to_post_confirmation=user_id,
channel_id_to_post_confirmation=channel_id,
thread_ts_to_post_confirmation=thread_ts,
tenant_id=client.tenant_id,
)
query_event_id, _, _ = decompose_action_id(feedback_id)
@@ -796,8 +797,9 @@ def process_message(
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
tenant_id = get_current_tenant_id()
logger.debug(
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
f"Received Slack request of type: '{req.type}' for tenant, {tenant_id}"
)
# Throw out requests that can't or shouldn't be handled
@@ -810,50 +812,39 @@ def process_message(
client=client.web_client, channel_id=channel
)
token: Token[str | None] | None = None
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
try:
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
with get_session_with_current_tenant() as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
)
follow_up = bool(
slack_channel_config.channel_config
and slack_channel_config.channel_config.get("follow_up_tags")
is not None
)
follow_up = bool(
slack_channel_config.channel_config
and slack_channel_config.channel_config.get("follow_up_tags") is not None
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
failed = handle_message(
message_info=details,
slack_channel_config=slack_channel_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)
failed = handle_message(
message_info=details,
slack_channel_config=slack_channel_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender_id,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
finally:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender_id,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:

View File

@@ -4,6 +4,8 @@ import re
import string
import time
import uuid
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from typing import cast
@@ -30,7 +32,7 @@ from onyx.configs.onyxbot_configs import (
)
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
@@ -43,6 +45,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.text_processing import replace_whitespaces_w_space
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -569,9 +572,7 @@ def read_slack_thread(
return thread_messages
def slack_usage_report(
action: str, sender_id: str | None, client: WebClient, tenant_id: str
) -> None:
def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None:
if DISABLE_TELEMETRY:
return
@@ -583,14 +584,13 @@ def slack_usage_report(
logger.warning("Unable to find sender email")
if sender_email is not None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
onyx_user = get_user_by_email(email=sender_email, db_session=db_session)
optional_telemetry(
record_type=RecordType.USAGE,
data={"action": action},
user_id=str(onyx_user.id) if onyx_user else "Non-Onyx-Or-No-Auth-User",
tenant_id=tenant_id,
)
@@ -665,5 +665,28 @@ def get_feedback_visibility() -> FeedbackVisibility:
class TenantSocketModeClient(SocketModeClient):
def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.tenant_id = tenant_id
self._tenant_id = tenant_id
self.slack_bot_id = slack_bot_id
@contextmanager
def _set_tenant_context(self) -> Generator[None, None, None]:
token = None
try:
if self._tenant_id:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self._tenant_id)
yield
finally:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def enqueue_message(self, message: str) -> None:
with self._set_tenant_context():
super().enqueue_message(message)
def process_message(self) -> None:
with self._set_tenant_context():
super().process_message()
def run_message_listeners(self, message: dict, raw_message: str) -> None:
with self._set_tenant_context():
super().run_message_listeners(message, raw_message)

View File

@@ -646,7 +646,6 @@ def associate_credential_to_connector(
)
return response
except ValidationError as e:
# If validation fails, delete the connector and commit the changes
# Ensures we don't leave invalid connectors in the database
@@ -660,10 +659,14 @@ def associate_credential_to_connector(
)
except IntegrityError as e:
logger.error(f"IntegrityError: {e}")
delete_connector(db_session, connector_id)
db_session.commit()
raise HTTPException(status_code=400, detail="Name must be unique")
except Exception as e:
logger.exception(f"Unexpected error: {e}")
raise HTTPException(status_code=500, detail="Unexpected error")

View File

@@ -343,7 +343,8 @@ def list_bot_configs(
]
MAX_CHANNELS = 200
MAX_SLACK_PAGES = 5
SLACK_API_CHANNELS_PER_PAGE = 100
@router.get(
@@ -355,8 +356,8 @@ def get_all_channels_from_slack_api(
_: User | None = Depends(current_admin_user),
) -> list[SlackChannel]:
"""
Fetches all channels from the Slack API.
If the workspace has 200 or more channels, we raise an error.
Fetches channels the bot is a member of from the Slack API.
Handles pagination with a limit to avoid excessive API calls.
"""
tokens = fetch_slack_bot_tokens(db_session, bot_id)
if not tokens or "bot_token" not in tokens:
@@ -365,28 +366,60 @@ def get_all_channels_from_slack_api(
)
client = WebClient(token=tokens["bot_token"])
all_channels = []
next_cursor = None
current_page = 0
try:
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=MAX_CHANNELS,
)
# Use users_conversations with limited pagination
while current_page < MAX_SLACK_PAGES:
current_page += 1
# Make API call with cursor if we have one
if next_cursor:
response = client.users_conversations(
types="public_channel,private_channel",
exclude_archived=True,
cursor=next_cursor,
limit=SLACK_API_CHANNELS_PER_PAGE,
)
else:
response = client.users_conversations(
types="public_channel,private_channel",
exclude_archived=True,
limit=SLACK_API_CHANNELS_PER_PAGE,
)
# Add channels to our list
if "channels" in response and response["channels"]:
all_channels.extend(response["channels"])
# Check if we need to paginate
if (
"response_metadata" in response
and "next_cursor" in response["response_metadata"]
):
next_cursor = response["response_metadata"]["next_cursor"]
if next_cursor:
if current_page == MAX_SLACK_PAGES:
raise HTTPException(
status_code=400,
detail="Workspace has too many channels to paginate over in this call.",
)
continue
# If we get here, no more pages
break
channels = [
SlackChannel(id=channel["id"], name=channel["name"])
for channel in response["channels"]
for channel in all_channels
]
if len(channels) == MAX_CHANNELS:
raise HTTPException(
status_code=400,
detail=f"Workspace has {MAX_CHANNELS} or more channels.",
)
return channels
except SlackApiError as e:
# Handle rate limiting or other API errors
raise HTTPException(
status_code=500,
detail=f"Error fetching channels from Slack API: {str(e)}",

View File

@@ -242,6 +242,7 @@ class ChatMessageDetail(BaseModel):
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None
refined_answer_improvement: bool | None = None
is_agentic: bool | None = None
error: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore

View File

@@ -199,17 +199,17 @@ export function SlackChannelConfigFormFields({
<Badge variant="agent" className="bg-blue-100 text-blue-800">
Default Configuration
</Badge>
<p className="mt-2 text-sm text-gray-600">
<p className="mt-2 text-sm text-neutral-600">
This default configuration will apply across all Slack channels
the bot is added to in the Slack workspace, as well as direct
messages (DMs), unless disabled.
</p>
<div className="mt-4 p-4 bg-gray-100 rounded-md border border-gray-300">
<div className="mt-4 p-4 bg-neutral-100 rounded-md border border-neutral-300">
<CheckFormField
name="disabled"
label="Disable Default Configuration"
/>
<p className="mt-2 text-sm text-gray-600 italic">
<p className="mt-2 text-sm text-neutral-600 italic">
Warning: Disabling the default configuration means the bot
won&apos;t respond in Slack channels or DMs unless explicitly
configured for them.
@@ -238,20 +238,28 @@ export function SlackChannelConfigFormFields({
/>
</div>
) : (
<Field name="channel_name">
{({ field, form }: { field: any; form: any }) => (
<SearchMultiSelectDropdown
options={channelOptions || []}
onSelect={(selected) => {
form.setFieldValue("channel_name", selected.name);
}}
initialSearchTerm={field.value}
onSearchTermChange={(term) => {
form.setFieldValue("channel_name", term);
}}
/>
)}
</Field>
<>
<Field name="channel_name">
{({ field, form }: { field: any; form: any }) => (
<SearchMultiSelectDropdown
options={channelOptions || []}
onSelect={(selected) => {
form.setFieldValue("channel_name", selected.name);
}}
initialSearchTerm={field.value}
onSearchTermChange={(term) => {
form.setFieldValue("channel_name", term);
}}
/>
)}
</Field>
<p className="mt-2 text-sm dark:text-neutral-400 text-neutral-600">
Note: This list shows public and private channels where the
bot is a member (up to 500 channels). If you don&apos;t see a
channel, make sure the bot is added to that channel in Slack
first, or type the channel name manually.
</p>
</>
)}
</>
)}

View File

@@ -0,0 +1,16 @@
export default function AuthErrorLayout({
children,
}: {
children: React.ReactNode;
}) {
// Log error to console for debugging
console.error(
"Authentication error page was accessed - this should not happen in normal flow"
);
// In a production environment, you might want to send this to your error tracking service
// For example, if using a service like Sentry:
// captureException(new Error("Authentication error page was accessed unexpectedly"));
return <>{children}</>;
}

View File

@@ -4,6 +4,7 @@ import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import { Button } from "@/components/ui/button";
import Link from "next/link";
import { FiLogIn } from "react-icons/fi";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
const Page = () => {
return (
@@ -15,19 +16,21 @@ const Page = () => {
<p className="text-text-700 text-center">
We encountered an issue while attempting to log you in.
</p>
<div className="bg-red-50 border border-red-200 rounded-lg p-4 shadow-sm">
<h3 className="text-red-800 font-semibold mb-2">Possible Issues:</h3>
<div className="bg-red-50 dark:bg-red-950/30 border border-red-200 dark:border-red-800 rounded-lg p-4 shadow-sm">
<h3 className="text-red-800 dark:text-red-400 font-semibold mb-2">
Possible Issues:
</h3>
<ul className="space-y-2">
<li className="flex items-center text-red-700">
<div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
<li className="flex items-center text-red-700 dark:text-red-400">
<div className="w-2 h-2 bg-red-500 dark:bg-red-400 rounded-full mr-2"></div>
Incorrect or expired login credentials
</li>
<li className="flex items-center text-red-700">
<div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
<li className="flex items-center text-red-700 dark:text-red-400">
<div className="w-2 h-2 bg-red-500 dark:bg-red-400 rounded-full mr-2"></div>
Temporary authentication system disruption
</li>
<li className="flex items-center text-red-700">
<div className="w-2 h-2 bg-red-500 rounded-full mr-2"></div>
<li className="flex items-center text-red-700 dark:text-red-400">
<div className="w-2 h-2 bg-red-500 dark:bg-red-400 rounded-full mr-2"></div>
Account access restrictions or permissions
</li>
</ul>
@@ -41,6 +44,12 @@ const Page = () => {
<p className="text-sm text-text-500 text-center">
We recommend trying again. If you continue to experience problems,
please reach out to your system administrator for assistance.
{NEXT_PUBLIC_CLOUD_ENABLED && (
<span className="block mt-1 text-blue-600">
A member of our team has been automatically notified about this
issue.
</span>
)}
</p>
</div>
</AuthFlowContainer>

View File

@@ -132,18 +132,12 @@ import {
import { getSourceMetadata } from "@/lib/sources";
import { UserSettingsModal } from "./modal/UserSettingsModal";
import { AlignStartVertical } from "lucide-react";
import { AgenticMessage } from "./message/AgenticMessage";
import AssistantModal from "../assistants/mine/AssistantModal";
import {
OperatingSystem,
useOperatingSystem,
useSidebarShortcut,
} from "@/lib/browserUtilities";
import { Button } from "@/components/ui/button";
import { useSidebarShortcut } from "@/lib/browserUtilities";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { MessageChannel } from "node:worker_threads";
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
import { ErrorBanner } from "./message/Resubmit";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -1169,6 +1163,7 @@ export function ChatPage({
navigatingAway.current = false;
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
setUncaughtError(null);
// Mark that we've sent a message for this session in the current page load
markSessionMessageSent(frozenSessionId);
@@ -1319,6 +1314,7 @@ export function ChatPage({
let isStreamingQuestions = true;
let includeAgentic = false;
let secondLevelMessageId: number | null = null;
let isAgentic: boolean = false;
let initialFetchDetails: null | {
user_message_id: number;
@@ -1481,6 +1477,9 @@ export function ChatPage({
second_level_generating = true;
}
}
if (Object.hasOwn(packet, "is_agentic")) {
isAgentic = (packet as any).is_agentic;
}
if (Object.hasOwn(packet, "refined_answer_improvement")) {
isImprovement = (packet as RefinedAnswerImprovement)
@@ -1514,6 +1513,7 @@ export function ChatPage({
);
} else if (Object.hasOwn(packet, "sub_question")) {
updateChatState("toolBuilding", frozenSessionId);
isAgentic = true;
is_generating = true;
sub_questions = constructSubQuestions(
sub_questions,
@@ -1714,6 +1714,7 @@ export function ChatPage({
sub_questions: sub_questions,
second_level_generating: second_level_generating,
agentic_docs: agenticDocs,
is_agentic: isAgentic,
},
...(includeAgentic
? [
@@ -2062,6 +2063,26 @@ export function ChatPage({
const [sharedChatSession, setSharedChatSession] =
useState<ChatSession | null>();
const handleResubmitLastMessage = () => {
// Grab the last user-type message
const lastUserMsg = messageHistory
.slice()
.reverse()
.find((m) => m.type === "user");
if (!lastUserMsg) {
setPopup({
message: "No previously-submitted user message found.",
type: "error",
});
return;
}
// We call onSubmit, passing a `messageOverride`
onSubmit({
messageIdToResend: lastUserMsg.messageId,
messageOverride: lastUserMsg.message,
});
};
const showShareModal = (chatSession: ChatSession) => {
setSharedChatSession(chatSession);
};
@@ -2644,9 +2665,9 @@ export function ChatPage({
: null
}
>
{message.sub_questions &&
message.sub_questions.length > 0 ? (
{message.is_agentic ? (
<AgenticMessage
resubmit={handleResubmitLastMessage}
error={uncaughtError}
isStreamingQuestions={
message.isStreamingQuestions ?? false
@@ -2994,21 +3015,18 @@ export function ChatPage({
currentPersona={liveAssistant}
messageId={message.messageId}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
{message.stackTrace && (
<span
onClick={() =>
setStackTraceModalContent(
message.stackTrace!
)
}
className="ml-2 cursor-pointer underline"
>
Show stack trace.
</span>
)}
</p>
<ErrorBanner
resubmit={handleResubmitLastMessage}
error={message.message}
showStackTrace={
message.stackTrace
? () =>
setStackTraceModalContent(
message.stackTrace!
)
: undefined
}
/>
}
/>
</div>

View File

@@ -15,8 +15,8 @@ export function ChatSearchGroup({
}: ChatSearchGroupProps) {
return (
<div className="mb-4">
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-gray-800/90 py-2 px-4 px-4">
<div className="text-xs font-medium leading-4 text-gray-600 dark:text-gray-400">
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-neutral-800/90 py-2 px-4 px-4">
<div className="text-xs font-medium leading-4 text-neutral-600 dark:text-neutral-400">
{title}
</div>
</div>

View File

@@ -1,6 +1,7 @@
import React from "react";
import { MessageSquare } from "lucide-react";
import { ChatSessionSummary } from "./interfaces";
import { truncateString } from "@/lib/utils";
interface ChatSearchItemProps {
chat: ChatSessionSummary;
@@ -11,12 +12,12 @@ export function ChatSearchItem({ chat, onSelect }: ChatSearchItemProps) {
return (
<li>
<div className="cursor-pointer" onClick={() => onSelect(chat.id)}>
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-800">
<div className="flex items-center">
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-700">
<div className="flex max-w-full mx-2 items-center">
<MessageSquare className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
<div className="relative grow overflow-hidden whitespace-nowrap pl-4">
<div className="text-sm dark:text-neutral-200">
{chat.name || "Untitled Chat"}
<div className="relative max-w-full grow overflow-hidden whitespace-nowrap pl-4">
<div className="text-sm max-w-full dark:text-neutral-200">
{truncateString(chat.name || "Untitled Chat", 90)}
</div>
</div>
<div className="opacity-0 group-hover:opacity-100 transition-opacity text-xs text-neutral-500 dark:text-neutral-400">

View File

@@ -1,8 +1,8 @@
import React, { useState, useEffect } from "react";
import React, { useState, useEffect, useCallback } from "react";
import { InputPrompt } from "@/app/chat/interfaces";
import { Button } from "@/components/ui/button";
import { TrashIcon, PlusIcon } from "@/components/icons/icons";
import { MoreVertical, CheckIcon, XIcon } from "lucide-react";
import { PlusIcon } from "@/components/icons/icons";
import { MoreVertical, XIcon } from "lucide-react";
import { Textarea } from "@/components/ui/textarea";
import Title from "@/components/ui/title";
import Text from "@/components/ui/text";
@@ -153,114 +153,6 @@ export default function InputPrompts() {
}
};
const PromptCard = ({ prompt }: { prompt: InputPrompt }) => {
const isEditing = editingPromptId === prompt.id;
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
const [localContent, setLocalContent] = useState(prompt.content);
// Sync local edits with any prompt changes from outside
useEffect(() => {
setLocalPrompt(prompt.prompt);
setLocalContent(prompt.content);
}, [prompt, isEditing]);
const handleLocalEdit = (field: "prompt" | "content", value: string) => {
if (field === "prompt") {
setLocalPrompt(value);
} else {
setLocalContent(value);
}
};
const handleSaveLocal = () => {
handleSave(prompt.id, localPrompt, localContent);
};
return (
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
{isEditing ? (
<>
<div className="absolute top-2 right-2">
<Button
variant="ghost"
size="sm"
onClick={() => {
setEditingPromptId(null);
fetchInputPrompts(); // Revert changes from server
}}
>
<XIcon size={14} />
</Button>
</div>
<div className="flex">
<div className="flex-grow mr-4">
<Textarea
value={localPrompt}
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
className="mb-2 resize-none"
placeholder="Prompt"
/>
<Textarea
value={localContent}
onChange={(e) => handleLocalEdit("content", e.target.value)}
className="resize-vertical min-h-[100px]"
placeholder="Content"
/>
</div>
<div className="flex items-end">
<Button onClick={handleSaveLocal}>
{prompt.id ? "Save" : "Create"}
</Button>
</div>
</div>
</>
) : (
<>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div className="mb-2 flex gap-x-2 ">
<p className="font-semibold">{prompt.prompt}</p>
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
</div>
</TooltipTrigger>
{isPromptPublic(prompt) && (
<TooltipContent>
<p>This is a built-in prompt and cannot be edited</p>
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
<div className="whitespace-pre-wrap">{prompt.content}</div>
<div className="absolute top-2 right-2">
<DropdownMenu>
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
<Button
className="!hover:bg-transparent"
variant="ghost"
size="sm"
>
<MoreVertical size={14} />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent>
{!isPromptPublic(prompt) && (
<DropdownMenuItem onClick={() => handleEdit(prompt.id)}>
Edit
</DropdownMenuItem>
)}
<DropdownMenuItem onClick={() => handleDelete(prompt.id)}>
Delete
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
</>
)}
</div>
);
};
return (
<div className="mx-auto max-w-4xl">
<div className="absolute top-4 left-4">
@@ -272,13 +164,21 @@ export default function InputPrompts() {
<Title>Prompt Shortcuts</Title>
<Text>
Manage and customize prompt shortcuts for your assistants. Use your
prompt shortcuts by starting a new message / in chat.
prompt shortcuts by starting a new message with &quot;/&quot; in
chat.
</Text>
</div>
</div>
{inputPrompts.map((prompt) => (
<PromptCard key={prompt.id} prompt={prompt} />
<PromptCard
key={prompt.id}
prompt={prompt}
onEdit={handleEdit}
onSave={handleSave}
onDelete={handleDelete}
isEditing={editingPromptId === prompt.id}
/>
))}
{isCreatingNew ? (
@@ -315,3 +215,129 @@ export default function InputPrompts() {
</div>
);
}
interface PromptCardProps {
prompt: InputPrompt;
onEdit: (id: number) => void;
onSave: (id: number, prompt: string, content: string) => void;
onDelete: (id: number) => void;
isEditing: boolean;
}
const PromptCard: React.FC<PromptCardProps> = ({
prompt,
onEdit,
onSave,
onDelete,
isEditing,
}) => {
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
const [localContent, setLocalContent] = useState(prompt.content);
useEffect(() => {
setLocalPrompt(prompt.prompt);
setLocalContent(prompt.content);
}, [prompt, isEditing]);
const handleLocalEdit = useCallback(
(field: "prompt" | "content", value: string) => {
if (field === "prompt") {
setLocalPrompt(value);
} else {
setLocalContent(value);
}
},
[]
);
const handleSaveLocal = useCallback(() => {
onSave(prompt.id, localPrompt, localContent);
}, [prompt.id, localPrompt, localContent, onSave]);
const isPromptPublic = useCallback((p: InputPrompt): boolean => {
return p.is_public;
}, []);
return (
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
{isEditing ? (
<>
<div className="absolute top-2 right-2">
<Button
variant="ghost"
size="sm"
onClick={() => {
onEdit(0);
}}
>
<XIcon size={14} />
</Button>
</div>
<div className="flex">
<div className="flex-grow mr-4">
<Textarea
value={localPrompt}
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
className="mb-2 resize-none"
placeholder="Prompt"
/>
<Textarea
value={localContent}
onChange={(e) => handleLocalEdit("content", e.target.value)}
className="resize-vertical min-h-[100px]"
placeholder="Content"
/>
</div>
<div className="flex items-end">
<Button onClick={handleSaveLocal}>
{prompt.id ? "Save" : "Create"}
</Button>
</div>
</div>
</>
) : (
<>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div className="mb-2 flex gap-x-2 ">
<p className="font-semibold">{prompt.prompt}</p>
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
</div>
</TooltipTrigger>
{isPromptPublic(prompt) && (
<TooltipContent>
<p>This is a built-in prompt and cannot be edited</p>
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
<div className="whitespace-pre-wrap">{prompt.content}</div>
<div className="absolute top-2 right-2">
<DropdownMenu>
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
<Button
className="!hover:bg-transparent"
variant="ghost"
size="sm"
>
<MoreVertical size={14} />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent>
{!isPromptPublic(prompt) && (
<DropdownMenuItem onClick={() => onEdit(prompt.id)}>
Edit
</DropdownMenuItem>
)}
<DropdownMenuItem onClick={() => onDelete(prompt.id)}>
Delete
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
</>
)}
</div>
);
};

View File

@@ -0,0 +1,147 @@
import { SourceChip } from "../input/ChatInputBar";
import { useEffect } from "react";
import { useState } from "react";
import { InputPrompt } from "../interfaces";
import { Button } from "@/components/ui/button";
import { XIcon } from "@/components/icons/icons";
import { Textarea } from "@/components/ui/textarea";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
import { MoreVertical } from "lucide-react";
export const PromptCard = ({
prompt,
editingPromptId,
setEditingPromptId,
handleSave,
handleDelete,
isPromptPublic,
handleEdit,
fetchInputPrompts,
}: {
prompt: InputPrompt;
editingPromptId: number | null;
setEditingPromptId: (id: number | null) => void;
handleSave: (id: number, prompt: string, content: string) => void;
handleDelete: (id: number) => void;
isPromptPublic: (prompt: InputPrompt) => boolean;
handleEdit: (id: number) => void;
fetchInputPrompts: () => void;
}) => {
const isEditing = editingPromptId === prompt.id;
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
const [localContent, setLocalContent] = useState(prompt.content);
// Sync local edits with any prompt changes from outside
useEffect(() => {
setLocalPrompt(prompt.prompt);
setLocalContent(prompt.content);
}, [prompt, isEditing]);
const handleLocalEdit = (field: "prompt" | "content", value: string) => {
if (field === "prompt") {
setLocalPrompt(value);
} else {
setLocalContent(value);
}
};
const handleSaveLocal = () => {
handleSave(prompt.id, localPrompt, localContent);
};
return (
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
{isEditing ? (
<>
<div className="absolute top-2 right-2">
<Button
variant="ghost"
size="sm"
onClick={() => {
setEditingPromptId(null);
fetchInputPrompts(); // Revert changes from server
}}
>
<XIcon size={14} />
</Button>
</div>
<div className="flex">
<div className="flex-grow mr-4">
<Textarea
value={localPrompt}
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
className="mb-2 resize-none"
placeholder="Prompt"
/>
<Textarea
value={localContent}
onChange={(e) => handleLocalEdit("content", e.target.value)}
className="resize-vertical min-h-[100px]"
placeholder="Content"
/>
</div>
<div className="flex items-end">
<Button onClick={handleSaveLocal}>
{prompt.id ? "Save" : "Create"}
</Button>
</div>
</div>
</>
) : (
<>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div className="mb-2 flex gap-x-2 ">
<p className="font-semibold">{prompt.prompt}</p>
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
</div>
</TooltipTrigger>
{isPromptPublic(prompt) && (
<TooltipContent>
<p>This is a built-in prompt and cannot be edited</p>
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
<div className="whitespace-pre-wrap">{prompt.content}</div>
<div className="absolute top-2 right-2">
<DropdownMenu>
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
<Button
className="!hover:bg-transparent"
variant="ghost"
size="sm"
>
<MoreVertical size={14} />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent>
{!isPromptPublic(prompt) && (
<DropdownMenuItem onClick={() => handleEdit(prompt.id)}>
Edit
</DropdownMenuItem>
)}
<DropdownMenuItem onClick={() => handleDelete(prompt.id)}>
Delete
</DropdownMenuItem>
</DropdownMenuContent>
</DropdownMenu>
</div>
</>
)}
</div>
);
};

View File

@@ -154,7 +154,6 @@ export const SourceChip = ({
gap-x-1
h-6
${onClick ? "cursor-pointer" : ""}
animate-fade-in-scale
`}
>
{icon}

View File

@@ -104,6 +104,7 @@ export interface Message {
overridden_model?: string;
stopReason?: StreamStopReason | null;
sub_questions?: SubQuestionDetail[] | null;
is_agentic?: boolean | null;
// Streaming only
second_level_generating?: boolean;
@@ -150,6 +151,7 @@ export interface BackendMessage {
comments: any;
parentMessageId: number | null;
refined_answer_improvement: boolean | null;
is_agentic: boolean | null;
}
export interface MessageResponseIDInfo {

View File

@@ -501,6 +501,7 @@ export function processRawChatHistory(
sub_questions: subQuestions,
isImprovement:
(messageInfo.refined_answer_improvement as unknown as boolean) || false,
is_agentic: messageInfo.is_agentic,
};
messages.set(messageInfo.message_id, message);

View File

@@ -50,6 +50,9 @@ import "katex/dist/katex.min.css";
import SubQuestionsDisplay from "./SubQuestionsDisplay";
import { StatusRefinement } from "../Refinement";
import { copyAll, handleCopy } from "./copyingUtils";
import { Button } from "@/components/ui/button";
import { RefreshCw } from "lucide-react";
import { ErrorBanner, Resubmit } from "./Resubmit";
export const AgenticMessage = ({
isStreamingQuestions,
@@ -84,7 +87,9 @@ export const AgenticMessage = ({
secondLevelSubquestions,
toggleDocDisplay,
error,
resubmit,
}: {
resubmit?: () => void;
isStreamingQuestions: boolean;
isGenerating: boolean;
docSidebarToggled?: boolean;
@@ -455,7 +460,6 @@ export const AgenticMessage = ({
finalContent.length > 8) ||
(files && files.length > 0) ? (
<>
{/* <FileDisplay files={files || []} /> */}
<div className="w-full py-4 flex flex-col gap-4">
<div className="flex items-center gap-x-2 px-4">
<div className="text-black text-lg font-medium">
@@ -503,9 +507,7 @@ export const AgenticMessage = ({
content
)}
{error && (
<p className="mt-2 text-red-700 text-sm my-auto">
{error}
</p>
<ErrorBanner error={error} resubmit={resubmit} />
)}
</div>
</div>
@@ -513,15 +515,13 @@ export const AgenticMessage = ({
) : isComplete ? (
error && (
<p className="mt-2 mx-4 text-red-700 text-sm my-auto">
{error}
<ErrorBanner error={error} resubmit={resubmit} />
</p>
)
) : (
<>
{error && (
<p className="mt-2 mx-4 text-red-700 text-sm my-auto">
{error}
</p>
<ErrorBanner error={error} resubmit={resubmit} />
)}
</>
)}

View File

@@ -0,0 +1,58 @@
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import { AlertCircle } from "lucide-react";
import { Button } from "@/components/ui/button";
import { RefreshCw } from "lucide-react";
interface ResubmitProps {
resubmit: () => void;
}
export const Resubmit: React.FC<ResubmitProps> = ({ resubmit }) => {
return (
<div className="flex flex-col items-center justify-center gap-y-2 mt-4">
<p className="text-sm text-neutral-700 dark:text-neutral-300">
There was an error with the response.
</p>
<Button
onClick={resubmit}
variant="agent"
size="sm"
className="flex items-center gap-2 text-white font-medium py-2 px-4 rounded"
>
<RefreshCw className="w-4 h-4" />
Regenerate
</Button>
</div>
);
};
export const ErrorBanner = ({
error,
showStackTrace,
resubmit,
}: {
error: string;
showStackTrace?: () => void;
resubmit?: () => void;
}) => {
return (
<div className="text-red-700 mt-4 text-sm my-auto">
<Alert variant="broken">
<AlertCircle className="h-4 w-4" />
<AlertTitle>Error</AlertTitle>
<AlertDescription className="flex gap-x-2">
{error}
{showStackTrace && (
<span
className="text-red-600 hover:text-red-800 cursor-pointer underline"
onClick={showStackTrace}
>
Show stack trace
</span>
)}
</AlertDescription>
</Alert>
{resubmit && <Resubmit resubmit={resubmit} />}
</div>
);
};

View File

@@ -40,8 +40,13 @@ export function UserSettingsModal({
onClose: () => void;
defaultModel: string | null;
}) {
const { refreshUser, user, updateUserAutoScroll, updateUserShortcuts } =
useUser();
const {
refreshUser,
user,
updateUserAutoScroll,
updateUserShortcuts,
updateUserTemperatureOverrideEnabled,
} = useUser();
const containerRef = useRef<HTMLDivElement>(null);
const messageRef = useRef<HTMLDivElement>(null);
const { theme, setTheme } = useTheme();
@@ -156,11 +161,6 @@ export function UserSettingsModal({
const settings = useContext(SettingsContext);
const autoScroll = settings?.settings?.auto_scroll;
const checked =
user?.preferences?.auto_scroll === null
? autoScroll
: user?.preferences?.auto_scroll;
const handleChangePassword = async (e: React.FormEvent) => {
e.preventDefault();
if (newPassword !== confirmPassword) {
@@ -288,12 +288,26 @@ export function UserSettingsModal({
<SubLabel>Automatically scroll to new content</SubLabel>
</div>
<Switch
checked={checked}
checked={user?.preferences.auto_scroll}
onCheckedChange={(checked) => {
updateUserAutoScroll(checked);
}}
/>
</div>
<div className="flex items-center justify-between">
<div>
<h3 className="text-lg font-medium">
Temperature override
</h3>
<SubLabel>Set the temperature for the LLM</SubLabel>
</div>
<Switch
checked={user?.preferences.temperature_override_enabled}
onCheckedChange={(checked) => {
updateUserTemperatureOverrideEnabled(checked);
}}
/>
</div>
<div className="flex items-center justify-between">
<div>
<h3 className="text-lg font-medium">Prompt Shortcuts</h3>

View File

@@ -8,8 +8,10 @@ const alertVariants = cva(
{
variants: {
variant: {
broken:
"border-red-500/50 text-red-500 dark:border-red-500 [&>svg]:text-red-500 dark:border-red-900/50 dark:text-red-100 dark:dark:border-red-900 dark:[&>svg]:text-red-700 bg-red-50 dark:bg-red-950",
ark: "border-amber-500/50 text-amber-500 dark:border-amber-500 [&>svg]:text-amber-500 dark:border-amber-900/50 dark:text-amber-900 dark:dark:border-amber-900 dark:[&>svg]:text-amber-900 bg-amber-50 dark:bg-amber-950",
info: "border-black/50 dark:border-black dark:border-black/50 dark:dark:border-black",
default:
"bg-neutral-50 text-neutral-darker dark:bg-neutral-950 dark:text-text",
destructive:

View File

@@ -9,6 +9,8 @@ const buttonVariants = cva(
{
variants: {
variant: {
agent:
"bg-agent text-white hover:bg-agent-hovered dark:bg-agent dark:text-white dark:hover:bg-agent/90",
success:
"bg-green-100 text-green-600 hover:bg-green-500/90 dark:bg-green-700 dark:text-green-100 dark:hover:bg-green-600/90",
"success-reverse":

View File

@@ -13,7 +13,7 @@ interface UserContextType {
isCurator: boolean;
refreshUser: () => Promise<void>;
isCloudSuperuser: boolean;
updateUserAutoScroll: (autoScroll: boolean | null) => Promise<void>;
updateUserAutoScroll: (autoScroll: boolean) => Promise<void>;
updateUserShortcuts: (enabled: boolean) => Promise<void>;
toggleAssistantPinnedStatus: (
currentPinnedAssistantIDs: number[],
@@ -163,7 +163,7 @@ export function UserProvider({
}
};
const updateUserAutoScroll = async (autoScroll: boolean | null) => {
const updateUserAutoScroll = async (autoScroll: boolean) => {
try {
const response = await fetch("/api/auto-scroll", {
method: "PATCH",

View File

@@ -10,7 +10,7 @@ interface UserPreferences {
pinned_assistants?: number[];
default_model: string | null;
recent_assistants: number[];
auto_scroll: boolean | null;
auto_scroll: boolean;
shortcut_enabled: boolean;
temperature_override_enabled: boolean;
}