Compare commits

...

13 Commits

Author SHA1 Message Date
Richard Kuo (Onyx)
4f65c8ef54 fix log format, make collection copy more explicit for readability 2025-05-30 16:07:38 -07:00
Richard Kuo (Onyx)
acf1761a2c synchronize waiting, use non thread local redis locks 2025-05-30 15:26:33 -07:00
Richard Kuo (Onyx)
b55834a6a8 fix return condition 2025-05-30 15:01:08 -07:00
Richard Kuo (Onyx)
1a88d409b4 switch to punkt_tab 2025-05-30 14:49:34 -07:00
Richard Kuo (Onyx)
92324b6094 don't use self 2025-05-30 14:41:48 -07:00
Richard Kuo (Onyx)
337d642077 various fixes and notes 2025-05-30 14:27:20 -07:00
Richard Kuo (Onyx)
6c17381a7e enforce block list size limit 2025-05-30 14:22:05 -07:00
Richard Kuo (Onyx)
5b0b6acf26 .close isn't async 2025-05-30 13:10:12 -07:00
Richard Kuo (Onyx)
a8ca9ae523 safe msg get 2025-05-30 11:36:27 -07:00
Richard Kuo (Onyx)
eb0bb5180f just use if 2025-05-30 09:44:56 -07:00
Richard Kuo (Onyx)
0d1f0ede5e Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/slack-bot-2 2025-05-30 09:37:18 -07:00
Richard Kuo (Onyx)
9f0c0ee884 add logging 2025-05-30 09:36:42 -07:00
Richard Kuo (Onyx)
ec7bb2298f try fixing slack bot 2025-05-29 18:23:51 -07:00
6 changed files with 212 additions and 72 deletions

View File

@@ -62,7 +62,7 @@ def download_nltk_data() -> None:
resources = {
"stopwords": "corpora/stopwords",
# "wordnet": "corpora/wordnet", # Not in use
"punkt": "tokenizers/punkt",
"punkt_tab": "tokenizers/punkt_tab",
}
for resource_name, resource_path in resources.items():

View File

@@ -64,7 +64,7 @@ TENANT_HEARTBEAT_INTERVAL = (
15 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = (
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
60 # How long before a tenant's heartbeat expires, allowing other pods to take over
)
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and check for new tokens

View File

@@ -137,7 +137,10 @@ def handle_generate_answer_button(
raise ValueError("Missing thread_ts in the payload")
thread_messages = read_slack_thread(
channel=channel_id, thread=thread_ts, client=client.web_client
tenant_id=client._tenant_id,
channel=channel_id,
thread=thread_ts,
client=client.web_client,
)
# remove all assistant messages till we get to the last user message
# we want the new answer to be generated off of the last "question" in

View File

@@ -419,6 +419,11 @@ def handle_regular_answer(
skip_ai_feedback=skip_ai_feedback,
)
# NOTE(rkuo): Slack has a maximum block list size of 50.
# we should modify build_slack_response_blocks to respect the max
# but enforcing the hard limit here is the last resort.
all_blocks = all_blocks[:50]
try:
respond_in_thread_or_channel(
client=client,

View File

@@ -1,4 +1,3 @@
import asyncio
import os
import signal
import sys
@@ -11,8 +10,8 @@ from types import FrameType
from typing import Any
from typing import cast
from typing import Dict
from typing import Set
import psycopg2.errors
from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
@@ -87,7 +86,7 @@ from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import check_message_limit
from onyx.onyxbot.slack.utils import decompose_action_id
from onyx.onyxbot.slack.utils import get_channel_name_from_id
from onyx.onyxbot.slack.utils import get_onyx_bot_slack_bot_id
from onyx.onyxbot.slack.utils import get_onyx_bot_auth_ids
from onyx.onyxbot.slack.utils import read_slack_thread
from onyx.onyxbot.slack.utils import remove_onyx_bot_tag
from onyx.onyxbot.slack.utils import rephrase_slack_message
@@ -135,7 +134,7 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str] = set()
self.tenant_ids: set[str] = set()
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
@@ -146,6 +145,9 @@ class SlackbotHandler:
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
self._lock = threading.Lock()
logger.info(f"Pod ID: {self.pod_id}")
# Set up signal handlers for graceful shutdown
@@ -169,6 +171,7 @@ class SlackbotHandler:
self.acquire_thread.start()
self.heartbeat_thread.start()
logger.info("Background threads started")
def get_pod_id(self) -> str:
@@ -194,12 +197,18 @@ class SlackbotHandler:
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
def heartbeat_loop(self) -> None:
"""This heartbeats into redis.
NOTE(rkuo): this is not thread-safe with acquire_tenants_loop and will
occasionally exception. Fix it!
"""
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
)
with self._lock:
tenant_ids = self.tenant_ids.copy()
SlackbotHandler.send_heartbeats(self.pod_id, tenant_ids)
logger.debug(f"Sent heartbeats for {len(tenant_ids)} active tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
@@ -224,7 +233,7 @@ class SlackbotHandler:
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.socket_clients[tenant_bot_pair].close()
del self.socket_clients[tenant_bot_pair]
del self.slack_bot_tokens[tenant_bot_pair]
return
@@ -252,9 +261,20 @@ class SlackbotHandler:
# Close any existing connection first
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.socket_clients[tenant_bot_pair].close()
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
socket_client = self.start_socket_client(
bot.id, tenant_id, slack_bot_tokens
)
if socket_client:
# Ensure tenant is tracked as active
self.socket_clients[tenant_id, bot.id] = socket_client
logger.info(
f"Started SocketModeClient: {tenant_id=} {socket_client.bot_name=} {bot.id=}"
)
self.tenant_ids.add(tenant_id)
def acquire_tenants(self) -> None:
"""
@@ -301,8 +321,12 @@ class SlackbotHandler:
redis_client = get_redis_client(tenant_id=tenant_id)
# Acquire a Redis lock (non-blocking)
# thread_local=False because the shutdown event is handled
# on an arbitrary thread
rlock: RedisLock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
OnyxRedisLocks.SLACK_BOT_LOCK,
timeout=TENANT_LOCK_EXPIRATION,
thread_local=False,
)
lock_acquired = rlock.acquire(blocking=False)
@@ -333,6 +357,10 @@ class SlackbotHandler:
except KvKeyNotFoundError:
# No Slackbot tokens, pass
pass
except psycopg2.errors.UndefinedTable:
logger.error(
"Undefined table error in fetch_slack_bots. Tenant schema may need fixing."
)
except Exception as e:
logger.exception(
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
@@ -409,10 +437,11 @@ class SlackbotHandler:
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
"""
socket_client_list = list(self.socket_clients.items())
# Close all socket clients for this tenant
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
for (t_id, slack_bot_id), client in socket_client_list:
if t_id == tenant_id:
asyncio.run(client.close())
client.close()
del self.socket_clients[(t_id, slack_bot_id)]
del self.slack_bot_tokens[(t_id, slack_bot_id)]
logger.info(
@@ -423,19 +452,22 @@ class SlackbotHandler:
if tenant_id in self.tenant_ids:
self.tenant_ids.remove(tenant_id)
def send_heartbeats(self) -> None:
@staticmethod
def send_heartbeats(pod_id: str, tenant_ids: set[str]) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
for tenant_id in self.tenant_ids:
logger.debug(f"Sending heartbeats for {len(tenant_ids)} active tenants")
for tenant_id in tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{pod_id}"
redis_client.set(
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
)
@staticmethod
def start_socket_client(
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
) -> None:
slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
) -> TenantSocketModeClient | None:
"""Returns the socket client if this succeeds"""
socket_client: TenantSocketModeClient = _get_socket_client(
slack_bot_tokens, tenant_id, slack_bot_id
)
@@ -451,18 +483,20 @@ class SlackbotHandler:
user_info["user"]["real_name"] or user_info["user"]["name"]
)
socket_client.bot_name = bot_name
logger.info(
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
)
# logger.info(
# f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
# )
except SlackApiError as e:
# Only error out if we get a not_authed error
if "not_authed" in str(e):
self.tenant_ids.add(tenant_id)
# for some reason we want to add the tenant to the list when this happens?
logger.error(
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
"Error: {e}"
f"Authentication error - Invalid or expired credentials: "
f"{tenant_id=} {slack_bot_id=}. "
f"Error: {e}"
)
return
return None
# Log other Slack API errors but continue
logger.error(
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -478,21 +512,20 @@ class SlackbotHandler:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info(
f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
)
# logger.debug(
# f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
# )
socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client
# Ensure tenant is tracked as active
self.tenant_ids.add(tenant_id)
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
)
# logger.info(
# f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
# )
return socket_client
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
asyncio.run(client.close())
client.close()
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
)
@@ -503,7 +536,11 @@ class SlackbotHandler:
logger.info("Shutting down gracefully")
self.running = False
self._shutdown_event.set()
self._shutdown_event.set() # set the shutdown event
# wait for threads to detect the event and exit
self.acquire_thread.join(timeout=60.0)
self.heartbeat_thread.join(timeout=60.0)
# Stop all socket clients
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
@@ -534,7 +571,13 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
"""True to keep going, False to ignore this Slack request"""
# skip cases where the bot is disabled in the web UI
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
tenant_id = get_current_tenant_id()
bot_token_user_id, bot_token_bot_id = get_onyx_bot_auth_ids(
tenant_id, client.web_client
)
logger.info(f"prefilter_requests: {bot_token_user_id=} {bot_token_bot_id=}")
with get_session_with_current_tenant() as db_session:
slack_bot = fetch_slack_bot(
db_session=db_session, slack_bot_id=client.slack_bot_id
@@ -581,7 +624,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
if (
msg in _SLACK_GREETINGS_TO_IGNORE
or remove_onyx_bot_tag(msg, client=client.web_client)
or remove_onyx_bot_tag(tenant_id, msg, client=client.web_client)
in _SLACK_GREETINGS_TO_IGNORE
):
channel_specific_logger.error(
@@ -600,15 +643,38 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
)
return False
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
bot_token_user_id, bot_token_bot_id = get_onyx_bot_auth_ids(
tenant_id, client.web_client
)
if event_type == "message":
is_onyx_bot_msg = False
is_tagged = False
event_user = event.get("user", "")
event_bot_id = event.get("bot_id", "")
# temporary debugging
if tenant_id == "tenant_i-04224818da13bf695":
logger.warning(
f"{tenant_id=} "
f"{bot_token_user_id=} "
f"{bot_token_bot_id=} "
f"{event=}"
)
is_dm = event.get("channel_type") == "im"
is_tagged = bot_tag_id and f"<@{bot_tag_id}>" in msg
is_onyx_bot_msg = bot_tag_id and bot_tag_id in event.get("user", "")
if bot_token_user_id and f"<@{bot_token_user_id}>" in msg:
is_tagged = True
if bot_token_user_id and bot_token_user_id in event_user:
is_onyx_bot_msg = True
if bot_token_bot_id and bot_token_bot_id in event_bot_id:
is_onyx_bot_msg = True
# OnyxBot should never respond to itself
if is_onyx_bot_msg:
logger.info("Ignoring message from OnyxBot")
logger.info("Ignoring message from OnyxBot (self-message)")
return False
# DMs with the bot don't pick up the @OnyxBot so we have to keep the
@@ -633,7 +699,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
)
# If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and (
if (not bot_token_user_id or bot_token_user_id not in msg) and (
not slack_channel_config
or not slack_channel_config.channel_config.get("respond_to_bots")
):
@@ -732,15 +798,16 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
def build_request_details(
req: SocketModeRequest, client: TenantSocketModeClient
) -> SlackMessageInfo:
tagged: bool = False
tenant_id = get_current_tenant_id()
if req.type == "events_api":
event = cast(dict[str, Any], req.payload["event"])
msg = cast(str, event["text"])
channel = cast(str, event["channel"])
# Check for both app_mention events and messages containing bot tag
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
tagged = (event.get("type") == "app_mention") or (
event.get("type") == "message" and bot_tag_id and f"<@{bot_tag_id}>" in msg
)
bot_token_user_id, _ = get_onyx_bot_auth_ids(tenant_id, client.web_client)
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
sender_id = event.get("user") or None
@@ -749,7 +816,7 @@ def build_request_details(
)
email = expert_info.email if expert_info else None
msg = remove_onyx_bot_tag(msg, client=client.web_client)
msg = remove_onyx_bot_tag(tenant_id, msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.info(f"Rephrasing Slack message. Original message: {msg}")
@@ -761,12 +828,24 @@ def build_request_details(
else:
logger.info(f"Received Slack message: {msg}")
event_type = event.get("type")
if event_type == "app_mention":
tagged = True
if event_type == "message":
if bot_token_user_id:
if f"<@{bot_token_user_id}>" in msg:
tagged = True
if tagged:
logger.debug("User tagged OnyxBot")
if thread_ts != message_ts and thread_ts is not None:
thread_messages = read_slack_thread(
channel=channel, thread=thread_ts, client=client.web_client
tenant_id=tenant_id,
channel=channel,
thread=thread_ts,
client=client.web_client,
)
else:
sender_display_name = None
@@ -843,12 +922,24 @@ def process_message(
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
tenant_id = get_current_tenant_id()
logger.debug(
f"Received Slack request of type: '{req.type}' for tenant, {tenant_id}"
)
if req.type == "events_api":
event = cast(dict[str, Any], req.payload["event"])
event_type = event.get("type")
msg = cast(str, event.get("text", ""))
logger.info(
f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=} "
f"{event_type=} {msg=}"
)
else:
logger.info(
f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=}"
)
# Throw out requests that can't or shouldn't be handled
if not prefilter_requests(req, client):
logger.info(
f"process_message prefiltered: {tenant_id=} {req.type=} {req.envelope_id=}"
)
return
details = build_request_details(req, client)
@@ -891,6 +982,10 @@ def process_message(
if notify_no_answer:
apologize_for_fail(details, client)
logger.info(
f"process_message finished: success={not failed} {tenant_id=} {req.type=} {req.envelope_id=}"
)
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
response = SocketModeResponse(envelope_id=req.envelope_id)

View File

@@ -2,6 +2,7 @@ import logging
import random
import re
import string
import threading
import time
import uuid
from collections.abc import Generator
@@ -48,17 +49,38 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
slack_token_user_ids: dict[str, str | None] = {}
slack_token_bot_ids: dict[str, str | None] = {}
slack_token_lock = threading.Lock()
_DANSWER_BOT_SLACK_BOT_ID: str | None = None
_DANSWER_BOT_MESSAGE_COUNT: int = 0
_DANSWER_BOT_COUNT_START_TIME: float = time.time()
def get_onyx_bot_slack_bot_id(web_client: WebClient) -> Any:
global _DANSWER_BOT_SLACK_BOT_ID
if _DANSWER_BOT_SLACK_BOT_ID is None:
_DANSWER_BOT_SLACK_BOT_ID = web_client.auth_test().get("user_id")
return _DANSWER_BOT_SLACK_BOT_ID
def get_onyx_bot_auth_ids(
    tenant_id: str, web_client: WebClient
) -> tuple[str | None, str | None]:
    """Return the bot's `(user_id, bot_id)` pair for a tenant, caching the result.

    The pair is looked up in the module-level `slack_token_user_ids` /
    `slack_token_bot_ids` dictionaries while holding `slack_token_lock`.
    If either value is missing (or was cached as None), `web_client.auth_test()`
    is called and its "user_id" / "bot_id" fields are stored back under the
    lock and returned. The lock is not held while `auth_test()` runs.

    NOTE(review): the cache is keyed by `tenant_id` only. If one tenant can
    run several Slack bots, they would all share a single cache entry --
    confirm whether that is intended.
    """
    global slack_token_user_ids
    global slack_token_bot_ids

    # Fast path: read both cached ids in a single critical section.
    with slack_token_lock:
        cached_user_id: str | None = slack_token_user_ids.get(tenant_id)
        cached_bot_id: str | None = slack_token_bot_ids.get(tenant_id)

    if cached_user_id is not None and cached_bot_id is not None:
        return cached_user_id, cached_bot_id

    # Cache miss (or incomplete entry): ask Slack who this token belongs to.
    auth_info = web_client.auth_test()
    fresh_user_id: str | None = auth_info.get("user_id")
    fresh_bot_id: str | None = auth_info.get("bot_id")

    # Values are stored even when None, so an incomplete answer is simply
    # re-queried on the next call.
    with slack_token_lock:
        slack_token_user_ids[tenant_id] = fresh_user_id
        slack_token_bot_ids[tenant_id] = fresh_bot_id

    return fresh_user_id, fresh_bot_id
def check_message_limit() -> bool:
@@ -146,9 +168,9 @@ def update_emote_react(
return
def remove_onyx_bot_tag(message_str: str, client: WebClient) -> str:
bot_tag_id = get_onyx_bot_slack_bot_id(web_client=client)
return re.sub(rf"<@{bot_tag_id}>\s*", "", message_str)
def remove_onyx_bot_tag(tenant_id: str, message_str: str, client: WebClient) -> str:
    """Remove every `<@BOT_USER_ID>` mention of this bot from a message.

    Args:
        tenant_id: tenant whose bot identity should be looked up.
        message_str: raw Slack message text.
        client: Slack web client passed through to `get_onyx_bot_auth_ids`.

    Returns:
        `message_str` with each bot mention, plus any whitespace directly
        following it, removed. If the bot's user id is unavailable the
        message is returned unchanged.
    """
    bot_token_user_id, _ = get_onyx_bot_auth_ids(tenant_id, web_client=client)

    # The user id may be None; without this guard the pattern would become the
    # literal "<@None>" and could strip unrelated text.
    if not bot_token_user_id:
        return message_str

    # Escape the id so it is always matched literally inside the pattern.
    return re.sub(rf"<@{re.escape(bot_token_user_id)}>\s*", "", message_str)
def _check_for_url_in_block(block: Block) -> bool:
@@ -218,7 +240,8 @@ def respond_in_thread_or_channel(
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
blocks_str = str(blocks)[:1024] # truncate block logging
logger.warning(f"Failed to post message: {e} \n blocks: {blocks_str}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
@@ -255,7 +278,8 @@ def respond_in_thread_or_channel(
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
blocks_str = str(blocks)[:1024] # truncate block logging
logger.warning(f"Failed to post message: {e} \n blocks: {blocks_str}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
@@ -518,7 +542,7 @@ def fetch_user_semantic_id_from_id(
def read_slack_thread(
channel: str, thread: str, client: WebClient
tenant_id: str, channel: str, thread: str, client: WebClient
) -> list[ThreadMessage]:
thread_messages: list[ThreadMessage] = []
response = client.conversations_replies(channel=channel, ts=thread)
@@ -532,9 +556,22 @@ def read_slack_thread(
)
message_type = MessageType.USER
else:
self_slack_bot_id = get_onyx_bot_slack_bot_id(client)
blocks: Any
if reply.get("user") == self_slack_bot_id:
is_onyx_bot_response = False
reply_user = reply.get("user")
reply_bot_id = reply.get("bot_id")
self_slack_bot_user_id, self_slack_bot_bot_id = get_onyx_bot_auth_ids(
tenant_id, client
)
if reply_user is not None and reply_user == self_slack_bot_user_id:
is_onyx_bot_response = True
if reply_bot_id is not None and reply_bot_id == self_slack_bot_bot_id:
is_onyx_bot_response = True
if is_onyx_bot_response:
# OnyxBot response
message_type = MessageType.ASSISTANT
user_sem_id = "Assistant"
@@ -576,7 +613,7 @@ def read_slack_thread(
logger.warning("Skipping Slack thread message, no text found")
continue
message = remove_onyx_bot_tag(message, client=client)
message = remove_onyx_bot_tag(tenant_id, message, client=client)
thread_messages.append(
ThreadMessage(message=message, sender=user_sem_id, role=message_type)
)