Compare commits

...

2 Commits

Author      SHA1        Message                      Date
Yuhong Sun  f5fcc15eab  Allow NLTK Failures (#1340)  2024-04-16 23:27:47 -07:00
Yuhong Sun  bd7e21a638  Lock Slack Keys (#1338)      2024-04-16 14:18:17 -07:00
4 changed files with 47 additions and 20 deletions

View File

@@ -3,7 +3,6 @@ from threading import Event
from typing import Any
from typing import cast
-import nltk # type: ignore
from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
@@ -42,6 +41,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.one_shot_answer.models import ThreadMessage
+from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
@@ -375,8 +375,7 @@ if __name__ == "__main__":
socket_client: SocketModeClient | None = None
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
-nltk.download("stopwords", quiet=True)
-nltk.download("punkt", quiet=True)
+download_nltk_data()
while True:
try:
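
The hunks above are from the Slack bot listener's entrypoint: the inline nltk.download(...) calls are replaced by the shared download_nltk_data() helper added in the third file below, so a failed or partial NLTK download is logged instead of crashing the bot at startup. A minimal sketch of the resulting startup sequence, using only names visible in this diff; the Slack client setup and polling loop are elided.

```python
# Sketch only: startup flow implied by the hunks above, not the full listener module.
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.utils.logger import setup_logger

logger = setup_logger()

if __name__ == "__main__":
    logger.info("Verifying query preprocessing (NLTK) data is downloaded")
    # Each missing resource is fetched best-effort; failures are logged, not raised,
    # so the bot can still start and fall back to simple tokenization.
    download_nltk_data()
```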

View File

@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
from typing import Any
from typing import cast
-import nltk # type:ignore
import uvicorn
from fastapi import APIRouter
from fastapi import FastAPI
@@ -51,6 +50,7 @@ from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_version
+from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.danswer_api.ingestion import get_danswer_api_key
from danswer.server.danswer_api.ingestion import router as danswer_api_router
@@ -205,9 +205,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.info("Reranking step of search flow is enabled.")
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
-nltk.download("stopwords", quiet=True)
-nltk.download("wordnet", quiet=True)
-nltk.download("punkt", quiet=True)
+download_nltk_data()
logger.info("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
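
The second file makes the same substitution inside the API server's lifespan hook (the hunk header shows async def lifespan(app: FastAPI) -> AsyncGenerator). As a refresher on the pattern, here is a minimal, self-contained sketch of startup work running in a FastAPI lifespan context manager; prepare_nltk is a hypothetical stand-in for the download_nltk_data() call above, not danswer code.

```python
# Minimal sketch of the FastAPI lifespan pattern used above (stand-alone, not danswer code).
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI


def prepare_nltk() -> None:
    # Hypothetical stand-in for download_nltk_data(): fetch corpora, tolerate failures.
    print("Verifying NLTK data is downloaded")


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    prepare_nltk()  # runs once at startup, before the app serves requests
    yield           # the application handles traffic here
    # shutdown work (if any) would go after the yield


app = FastAPI(lifespan=lifespan)
```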

View File

@@ -1,6 +1,7 @@
import string
from collections.abc import Callable
+import nltk # type:ignore
from nltk.corpus import stopwords # type:ignore
from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
@@ -31,21 +32,47 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
+def download_nltk_data():
+    resources = {
+        "stopwords": "corpora/stopwords",
+        "wordnet": "corpora/wordnet",
+        "punkt": "tokenizers/punkt",
+    }
+    for resource_name, resource_path in resources.items():
+        try:
+            nltk.data.find(resource_path)
+            logger.info(f"{resource_name} is already downloaded.")
+        except LookupError:
+            try:
+                logger.info(f"Downloading {resource_name}...")
+                nltk.download(resource_name, quiet=True)
+                logger.info(f"{resource_name} downloaded successfully.")
+            except Exception as e:
+                logger.error(f"Failed to download {resource_name}. Error: {e}")
def lemmatize_text(text: str) -> list[str]:
-    lemmatizer = WordNetLemmatizer()
-    word_tokens = word_tokenize(text)
-    return [lemmatizer.lemmatize(word) for word in word_tokens]
+    try:
+        lemmatizer = WordNetLemmatizer()
+        word_tokens = word_tokenize(text)
+        return [lemmatizer.lemmatize(word) for word in word_tokens]
+    except Exception:
+        return text.split(" ")
def remove_stop_words_and_punctuation(text: str) -> list[str]:
-    stop_words = set(stopwords.words("english"))
-    word_tokens = word_tokenize(text)
-    text_trimmed = [
-        word
-        for word in word_tokens
-        if (word.casefold() not in stop_words and word not in string.punctuation)
-    ]
-    return text_trimmed or word_tokens
+    try:
+        stop_words = set(stopwords.words("english"))
+        word_tokens = word_tokenize(text)
+        text_trimmed = [
+            word
+            for word in word_tokens
+            if (word.casefold() not in stop_words and word not in string.punctuation)
+        ]
+        return text_trimmed or word_tokens
+    except Exception:
+        return text.split(" ")
def query_processing(
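
This file adds the tolerance the commit title promises: download_nltk_data() probes each resource with nltk.data.find() and only logs on failure, while lemmatize_text and remove_stop_words_and_punctuation fall back to a plain whitespace split when NLTK data is unavailable. A rough usage sketch, assuming the danswer package is importable and using the module path shown in the earlier import hunks:

```python
# Illustration of the fallback behaviour; import path taken from the hunks above.
from danswer.search.retrieval.search_runner import (
    download_nltk_data,
    lemmatize_text,
    remove_stop_words_and_punctuation,
)

# Best-effort: logs each resource it cannot fetch instead of raising.
download_nltk_data()

query = "Why are the connectors failing?"

# With stopwords/wordnet/punkt present these return NLTK-processed tokens;
# without them, both helpers degrade to query.split(" ") so search keeps working.
print(lemmatize_text(query))
print(remove_stop_words_and_punctuation(query))
```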

View File

@@ -192,12 +192,15 @@ def list_slack_bot_configs(
@router.put("/admin/slack-bot/tokens")
-def put_tokens(tokens: SlackBotTokens) -> None:
+def put_tokens(
+    tokens: SlackBotTokens,
+    _: User | None = Depends(current_admin_user),
+) -> None:
save_tokens(tokens=tokens)
@router.get("/admin/slack-bot/tokens")
-def get_tokens() -> SlackBotTokens:
+def get_tokens(_: User | None = Depends(current_admin_user)) -> SlackBotTokens:
try:
return fetch_tokens()
except ConfigNotFoundError:
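
The final file carries the "Lock Slack Keys" change: both Slack bot token endpoints gain a current_admin_user dependency, so non-admin callers can no longer read or overwrite the tokens. Below is a simplified sketch of the Depends-based guard pattern; the stand-in dependency, header check, and placeholder response are illustrative only, since the real current_admin_user, User model, and token helpers live elsewhere in the codebase.

```python
# Simplified stand-in for the admin guard pattern used above (not danswer's actual auth).
from fastapi import Depends, FastAPI, Header, HTTPException

app = FastAPI()


def current_admin_user(x_role: str = Header(default="")) -> str:
    # Hypothetical check: the real dependency resolves the session user instead.
    if x_role != "admin":
        raise HTTPException(status_code=403, detail="Admin role required")
    return x_role


@app.get("/admin/slack-bot/tokens")
def get_tokens(_: str = Depends(current_admin_user)) -> dict[str, str]:
    # Only reached when the dependency above did not raise.
    return {"bot_token": "<redacted>", "app_token": "<redacted>"}
```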