Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)

Compare commits: fix_openap...v0.6.2 (28 commits)
| SHA1 |
|---|
| 9d4ecf817f |
| 881f814030 |
| 40013454f8 |
| 11080d3d69 |
| acaff41457 |
| dc35e9f6da |
| cdace77209 |
| 7e75e7c49e |
| fc18fd5b19 |
| 6dc9649e30 |
| 65c1918159 |
| 7678a356f5 |
| be4aef8696 |
| c28a8d831b |
| 3cafedcf22 |
| f548164464 |
| 2fd557c7ea |
| 3a8f06c765 |
| 7f583420e2 |
| c3f96c6be6 |
| 8a9e02c25a |
| fdd9ae347f |
| 85c1efcb25 |
| f4d7e34fa3 |
| 0817f91e6a |
| af20c13d8b |
| d724da9474 |
| fc45354aea |
.github/workflows/pr-python-checks.yml (vendored, 4 lines changed)

```diff
@@ -3,7 +3,9 @@ name: Python Checks
 on:
   merge_group:
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release/**'
 
 jobs:
   mypy-check:
```

A second workflow hunk in this range adds the Jira secrets alongside the existing Confluence ones for the connector tests:

```diff
@@ -15,6 +15,9 @@ env:
   CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
   CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
   CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
+  # Jira
+  JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
+  JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
 
 jobs:
   connectors-check:
```
.github/workflows/pr-python-tests.yml (vendored, 4 lines changed)

```diff
@@ -3,7 +3,9 @@ name: Python Unit Tests
 on:
   merge_group:
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release/**'
 
 jobs:
   backend-check:
```
.github/workflows/run-it.yml (vendored, 4 lines changed)

```diff
@@ -6,7 +6,9 @@ concurrency:
 on:
   merge_group:
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release/**'
 
 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
```
````text
@@ -28,4 +28,6 @@ MacOS will likely require you to remove some quarantine attributes on some of th
After installing pre-commit, run the following command:
```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```
```

doc version 0.1
````
```diff
@@ -101,7 +101,7 @@ COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connect
 # Put logo in assets
 COPY ./assets /app/assets
 
-ENV PYTHONPATH /app
+ENV PYTHONPATH=/app
 
 # Default command which does nothing
 # This container is used by api server and background which specify their own CMD
```
```diff
@@ -55,6 +55,6 @@ COPY ./shared_configs /app/shared_configs
 # Model Server main code
 COPY ./model_server /app/model_server
 
-ENV PYTHONPATH /app
+ENV PYTHONPATH=/app
 
 CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]
```
```diff
@@ -29,6 +29,7 @@ from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.embedder import DefaultIndexingEmbedder
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
 from danswer.indexing.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logger import IndexAttemptSingleton
 from danswer.utils.logger import setup_logger
@@ -103,15 +104,24 @@ def _run_indexing(
     )
 
     embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
-        search_settings=search_settings
+        search_settings=search_settings,
+        heartbeat=IndexingHeartbeat(
+            index_attempt_id=index_attempt.id,
+            db_session=db_session,
+            # let the world know we're still making progress after
+            # every 10 batches
+            freq=10,
+        ),
     )
 
     indexing_pipeline = build_indexing_pipeline(
         attempt_id=index_attempt.id,
         embedder=embedding_model,
         document_index=document_index,
-        ignore_time_skip=index_attempt.from_beginning
-        or (search_settings.status == IndexModelStatus.FUTURE),
+        ignore_time_skip=(
+            index_attempt.from_beginning
+            or (search_settings.status == IndexModelStatus.FUTURE)
+        ),
         db_session=db_session,
     )
 
```
```diff
@@ -167,7 +167,7 @@ REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # br
 # https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
 # should be one of "required", "optional", or "none"
 REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
-REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
+REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
 
 CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400))  # seconds
 
```
```diff
@@ -247,6 +247,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
     for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
     if ignored_tag
 ]
+# Maximum size for Jira tickets in bytes (default: 100KB)
+JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
+    os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
+)
 
 GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
 
```
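
To make the new limit concrete, here is a small, hedged sketch (not part of the diff) of the byte-size guard the connector applies; `exceeds_ticket_limit` is an illustrative helper name, and the default mirrors the config above.

```python
import os

# Default mirrors the config above: 100 KB unless overridden via the env var.
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
    os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)


def exceeds_ticket_limit(ticket_content: str) -> bool:
    # The check is on UTF-8 byte length, not character count, so multi-byte
    # characters contribute more than one unit toward the limit.
    return len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE


print(exceeds_ticket_limit("a" * 200_000))  # True: over the 100 KB default
print(exceeds_ticket_limit("short ticket"))  # False
```
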
```diff
@@ -9,6 +9,7 @@ from jira.resources import Issue
 
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
+from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
             else extract_text_from_adf(jira.raw["fields"]["description"])
         )
         comments = _get_comment_strs(jira, comment_email_blacklist)
-        semantic_rep = f"{description}\n" + "\n".join(
+        ticket_content = f"{description}\n" + "\n".join(
             [f"Comment: {comment}" for comment in comments if comment]
         )
 
+        # Check ticket size
+        if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
+            logger.info(
+                f"Skipping {jira.key} because it exceeds the maximum size of "
+                f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
+            )
+            continue
+
         page_url = f"{jira_client.client_info()}/browse/{jira.key}"
 
         people = set()
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
         doc_batch.append(
             Document(
                 id=page_url,
-                sections=[Section(link=page_url, text=semantic_rep)],
+                sections=[Section(link=page_url, text=ticket_content)],
                 source=DocumentSource.JIRA,
                 semantic_identifier=jira.fields.summary,
                 doc_updated_at=time_str_to_utc(jira.fields.updated),
@@ -236,10 +245,12 @@ class JiraConnector(LoadConnector, PollConnector):
         if self.jira_client is None:
             raise ConnectorMissingCredentialError("Jira")
 
+        # Quote the project name to handle reserved words
+        quoted_project = f'"{self.jira_project}"'
         start_ind = 0
         while True:
             doc_batch, fetched_batch_size = fetch_jira_issues_batch(
-                jql=f"project = {self.jira_project}",
+                jql=f"project = {quoted_project}",
                 start_index=start_ind,
                 jira_client=self.jira_client,
                 batch_size=self.batch_size,
@@ -267,8 +278,10 @@ class JiraConnector(LoadConnector, PollConnector):
             "%Y-%m-%d %H:%M"
         )
 
+        # Quote the project name to handle reserved words
+        quoted_project = f'"{self.jira_project}"'
         jql = (
-            f"project = {self.jira_project} AND "
+            f"project = {quoted_project} AND "
             f"updated >= '{start_date_str}' AND "
             f"updated <= '{end_date_str}'"
         )
```
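
For context on the JQL change, a minimal sketch (illustrative helper, not from the diff) of what the quoting produces; ORDER is only an assumed example of a project key that collides with a JQL reserved word.

```python
def build_project_jql(jira_project: str) -> str:
    # Mirror of the connector change: always double-quote the project name so
    # keys that collide with JQL reserved words (or contain spaces) stay valid.
    quoted_project = f'"{jira_project}"'
    return f"project = {quoted_project}"


print(build_project_jql("AS"))     # project = "AS"
print(build_project_jql("ORDER"))  # project = "ORDER"
```
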
```diff
@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
     get_metadata_keys_to_ignore,
 )
 from danswer.connectors.models import Document
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.indexing.models import DocAwareChunk
 from danswer.natural_language_processing.utils import BaseTokenizer
 from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ class Chunker:
         chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
         chunk_overlap: int = CHUNK_OVERLAP,
         mini_chunk_size: int = MINI_CHUNK_SIZE,
+        heartbeat: Heartbeat | None = None,
     ) -> None:
         from llama_index.text_splitter import SentenceSplitter
 
@@ -131,6 +133,7 @@ class Chunker:
         self.enable_multipass = enable_multipass
         self.enable_large_chunks = enable_large_chunks
         self.tokenizer = tokenizer
+        self.heartbeat = heartbeat
 
         self.blurb_splitter = SentenceSplitter(
             tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ class Chunker:
         # If the chunk does not have any useable content, it will not be indexed
         return chunks
 
-    def chunk(self, document: Document) -> list[DocAwareChunk]:
+    def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
         # Specifically for reproducing an issue with gmail
         if document.source == DocumentSource.GMAIL:
             logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ class Chunker:
             normal_chunks.extend(large_chunks)
 
         return normal_chunks
+
+    def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
+        final_chunks: list[DocAwareChunk] = []
+        for document in documents:
+            final_chunks.extend(self._handle_single_document(document))
+
+            if self.heartbeat:
+                self.heartbeat.heartbeat()
+
+        return final_chunks
```
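
A short usage sketch of the new batch-oriented API, adapted from the chunker unit test later in this compare; treat it as an illustration rather than a canonical call site (the model name and document contents come from that test).

```python
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import DefaultIndexingEmbedder

embedder = DefaultIndexingEmbedder(
    model_name="intfloat/e5-base-v2",
    normalize=True,
    query_prefix=None,
    passage_prefix=None,
)

document = Document(
    id="example_doc",
    source=DocumentSource.WEB,
    semantic_identifier="Example Document",
    metadata={"tags": ["tag1", "tag2"]},
    doc_updated_at=None,
    sections=[Section(text="This is a short section.", link="link1")],
)

chunker = Chunker(
    tokenizer=embedder.embedding_model.tokenizer,
    enable_multipass=False,
    # heartbeat is optional; the indexing pipeline passes an IndexingHeartbeat
    # here so progress is reported after every document.
)

# Old API: one call per document, e.g. chunker.chunk(document=document).
# New API: pass the whole batch; the optional heartbeat fires per document.
chunks = chunker.chunk([document])
print(len(chunks))
```
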
```diff
@@ -1,12 +1,8 @@
 from abc import ABC
 from abc import abstractmethod
 
-from sqlalchemy.orm import Session
-
-from danswer.db.models import IndexModelStatus
 from danswer.db.models import SearchSettings
-from danswer.db.search_settings import get_current_search_settings
-from danswer.db.search_settings import get_secondary_search_settings
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@ logger = setup_logger()
 
 
 class IndexingEmbedder(ABC):
+    """Converts chunks into chunks with embeddings. Note that one chunk may have
+    multiple embeddings associated with it."""
+
     def __init__(
         self,
         model_name: str,
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
         provider_type: EmbeddingProvider | None,
         api_key: str | None,
         api_url: str | None,
+        heartbeat: Heartbeat | None,
     ):
         self.model_name = model_name
         self.normalize = normalize
@@ -54,6 +54,7 @@
             server_host=INDEXING_MODEL_SERVER_HOST,
             server_port=INDEXING_MODEL_SERVER_PORT,
             retrim_content=True,
+            heartbeat=heartbeat,
         )
 
     @abstractmethod
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         provider_type: EmbeddingProvider | None = None,
         api_key: str | None = None,
         api_url: str | None = None,
+        heartbeat: Heartbeat | None = None,
     ):
         super().__init__(
             model_name,
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             provider_type,
             api_key,
             api_url,
+            heartbeat,
         )
 
     @log_function_time()
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
                     title_embed_dict[title] = title_embedding
 
             new_embedded_chunk = IndexChunk(
-                **chunk.dict(),
+                **chunk.model_dump(),
                 embeddings=ChunkEmbedding(
                     full_embedding=chunk_embeddings[0],
                     mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
 
     @classmethod
     def from_db_search_settings(
-        cls, search_settings: SearchSettings
+        cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
    ) -> "DefaultIndexingEmbedder":
         return cls(
             model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             provider_type=search_settings.provider_type,
             api_key=search_settings.api_key,
             api_url=search_settings.api_url,
+            heartbeat=heartbeat,
         )
-
-
-def get_embedding_model_from_search_settings(
-    db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
-) -> IndexingEmbedder:
-    search_settings: SearchSettings | None
-    if index_model_status == IndexModelStatus.PRESENT:
-        search_settings = get_current_search_settings(db_session)
-    elif index_model_status == IndexModelStatus.FUTURE:
-        search_settings = get_secondary_search_settings(db_session)
-        if not search_settings:
-            raise RuntimeError("No secondary index configured")
-    else:
-        raise RuntimeError("Not supporting embedding model rollbacks")
-
-    return DefaultIndexingEmbedder(
-        model_name=search_settings.model_name,
-        normalize=search_settings.normalize,
-        query_prefix=search_settings.query_prefix,
-        passage_prefix=search_settings.passage_prefix,
-        provider_type=search_settings.provider_type,
-        api_key=search_settings.api_key,
-        api_url=search_settings.api_url,
-    )
```
backend/danswer/indexing/indexing_heartbeat.py (new file, 41 lines)

```python
import abc
from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import Session

from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger

logger = setup_logger()


class Heartbeat(abc.ABC):
    """Useful for any long-running work that goes through a bunch of items
    and needs to occasionally give updates on progress.
    e.g. chunking, embedding, updating vespa, etc."""

    @abc.abstractmethod
    def heartbeat(self, metadata: Any = None) -> None:
        raise NotImplementedError


class IndexingHeartbeat(Heartbeat):
    def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
        self.cnt = 0

        self.index_attempt_id = index_attempt_id
        self.db_session = db_session
        self.freq = freq

    def heartbeat(self, metadata: Any = None) -> None:
        self.cnt += 1
        if self.cnt % self.freq == 0:
            index_attempt = get_index_attempt(
                db_session=self.db_session, index_attempt_id=self.index_attempt_id
            )
            if index_attempt:
                index_attempt.time_updated = func.now()
                self.db_session.commit()
            else:
                logger.error("Index attempt not found, this should not happen!")
```
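
As a quick orientation (not part of the diff), any long-running loop can drive a Heartbeat implementation. The counting subclass below mirrors the MockHeartbeat used in the new unit tests, and `process_items` is a made-up stand-in for chunking or embedding work.

```python
from typing import Any

from danswer.indexing.indexing_heartbeat import Heartbeat


class CountingHeartbeat(Heartbeat):
    """Toy implementation: counts calls instead of touching the database."""

    def __init__(self) -> None:
        self.call_count = 0

    def heartbeat(self, metadata: Any = None) -> None:
        self.call_count += 1


def process_items(items: list[str], heartbeat: Heartbeat | None = None) -> list[str]:
    # Hypothetical worker loop; the pattern matches how Chunker and
    # EmbeddingModel tick their optional heartbeat after each unit of work.
    results = []
    for item in items:
        results.append(item.upper())  # stand-in for real work
        if heartbeat:
            heartbeat.heartbeat()
    return results


hb = CountingHeartbeat()
process_items(["a", "b", "c"], heartbeat=hb)
print(hb.call_count)  # 3
```
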
```diff
@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
 from danswer.document_index.interfaces import DocumentMetadata
 from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import IndexingEmbedder
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.utils.logger import setup_logger
@@ -283,18 +284,10 @@ def index_doc_batch(
         return 0, 0
 
     logger.debug("Starting chunking")
-    chunks: list[DocAwareChunk] = []
-    for document in ctx.updatable_docs:
-        chunks.extend(chunker.chunk(document=document))
+    chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
 
     logger.debug("Starting embedding")
-    chunks_with_embeddings = (
-        embedder.embed_chunks(
-            chunks=chunks,
-        )
-        if chunks
-        else []
-    )
+    chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
 
     updatable_ids = [doc.id for doc in ctx.updatable_docs]
 
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
         tokenizer=embedder.embedding_model.tokenizer,
         enable_multipass=multipass,
         enable_large_chunks=enable_large_chunks,
+        # after every doc, update status in case there are a bunch of
+        # really long docs
+        heartbeat=IndexingHeartbeat(
+            index_attempt_id=attempt_id, db_session=db_session, freq=1
+        )
+        if attempt_id
+        else None,
     )
 
     return partial(
```
```diff
@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
 )
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ class EmbeddingModel:
         api_url: str | None,
         provider_type: EmbeddingProvider | None,
         retrim_content: bool = False,
+        heartbeat: Heartbeat | None = None,
     ) -> None:
         self.api_key = api_key
         self.provider_type = provider_type
@@ -107,6 +109,7 @@
         self.tokenizer = get_tokenizer(
             model_name=model_name, provider_type=provider_type
         )
+        self.heartbeat = heartbeat
 
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@
 
             response = self._make_model_server_request(embed_request)
             embeddings.extend(response.embeddings)
+
+            if self.heartbeat:
+                self.heartbeat.heartbeat()
         return embeddings
 
     def encode(
```
```diff
@@ -42,7 +42,7 @@ class RedisPool:
         db: int = REDIS_DB_NUMBER,
         password: str = REDIS_PASSWORD,
         max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
-        ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
+        ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
         ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
         ssl: bool = False,
     ) -> redis.ConnectionPool:
```
```diff
@@ -18,6 +18,7 @@ from danswer.db.models import Document
 from danswer.db.models import DocumentByConnectorCredentialPair
 from danswer.db.models import DocumentSet__UserGroup
 from danswer.db.models import LLMProvider__UserGroup
+from danswer.db.models import Persona__UserGroup
 from danswer.db.models import TokenRateLimit__UserGroup
 from danswer.db.models import User
 from danswer.db.models import User__UserGroup
@@ -33,6 +34,93 @@ from ee.danswer.server.user_group.models import UserGroupUpdate
 logger = setup_logger()
 
 
+def _cleanup_user__user_group_relationships__no_commit(
+    db_session: Session,
+    user_group_id: int,
+    user_ids: list[UUID] | None = None,
+) -> None:
+    """NOTE: does not commit the transaction."""
+    where_clause = User__UserGroup.user_group_id == user_group_id
+    if user_ids:
+        where_clause &= User__UserGroup.user_id.in_(user_ids)
+
+    user__user_group_relationships = db_session.scalars(
+        select(User__UserGroup).where(where_clause)
+    ).all()
+    for user__user_group_relationship in user__user_group_relationships:
+        db_session.delete(user__user_group_relationship)
+
+
+def _cleanup_credential__user_group_relationships__no_commit(
+    db_session: Session,
+    user_group_id: int,
+) -> None:
+    """NOTE: does not commit the transaction."""
+    db_session.query(Credential__UserGroup).filter(
+        Credential__UserGroup.user_group_id == user_group_id
+    ).delete(synchronize_session=False)
+
+
+def _cleanup_llm_provider__user_group_relationships__no_commit(
+    db_session: Session, user_group_id: int
+) -> None:
+    """NOTE: does not commit the transaction."""
+    db_session.query(LLMProvider__UserGroup).filter(
+        LLMProvider__UserGroup.user_group_id == user_group_id
+    ).delete(synchronize_session=False)
+
+
+def _cleanup_persona__user_group_relationships__no_commit(
+    db_session: Session, user_group_id: int
+) -> None:
+    """NOTE: does not commit the transaction."""
+    db_session.query(Persona__UserGroup).filter(
+        Persona__UserGroup.user_group_id == user_group_id
+    ).delete(synchronize_session=False)
+
+
+def _cleanup_token_rate_limit__user_group_relationships__no_commit(
+    db_session: Session, user_group_id: int
+) -> None:
+    """NOTE: does not commit the transaction."""
+    token_rate_limit__user_group_relationships = db_session.scalars(
+        select(TokenRateLimit__UserGroup).where(
+            TokenRateLimit__UserGroup.user_group_id == user_group_id
+        )
+    ).all()
+    for (
+        token_rate_limit__user_group_relationship
+    ) in token_rate_limit__user_group_relationships:
+        db_session.delete(token_rate_limit__user_group_relationship)
+
+
+def _cleanup_user_group__cc_pair_relationships__no_commit(
+    db_session: Session, user_group_id: int, outdated_only: bool
+) -> None:
+    """NOTE: does not commit the transaction."""
+    stmt = select(UserGroup__ConnectorCredentialPair).where(
+        UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
+    )
+    if outdated_only:
+        stmt = stmt.where(
+            UserGroup__ConnectorCredentialPair.is_current == False  # noqa: E712
+        )
+    user_group__cc_pair_relationships = db_session.scalars(stmt)
+    for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
+        db_session.delete(user_group__cc_pair_relationship)
+
+
+def _cleanup_document_set__user_group_relationships__no_commit(
+    db_session: Session, user_group_id: int
+) -> None:
+    """NOTE: does not commit the transaction."""
+    db_session.execute(
+        delete(DocumentSet__UserGroup).where(
+            DocumentSet__UserGroup.user_group_id == user_group_id
+        )
+    )
+
+
 def validate_user_creation_permissions(
     db_session: Session,
     user: User | None,
@@ -286,42 +374,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
     return db_user_group
 
 
-def _cleanup_user__user_group_relationships__no_commit(
-    db_session: Session,
-    user_group_id: int,
-    user_ids: list[UUID] | None = None,
-) -> None:
-    """NOTE: does not commit the transaction."""
-    where_clause = User__UserGroup.user_group_id == user_group_id
-    if user_ids:
-        where_clause &= User__UserGroup.user_id.in_(user_ids)
-
-    user__user_group_relationships = db_session.scalars(
-        select(User__UserGroup).where(where_clause)
-    ).all()
-    for user__user_group_relationship in user__user_group_relationships:
-        db_session.delete(user__user_group_relationship)
-
-
-def _cleanup_credential__user_group_relationships__no_commit(
-    db_session: Session,
-    user_group_id: int,
-) -> None:
-    """NOTE: does not commit the transaction."""
-    db_session.query(Credential__UserGroup).filter(
-        Credential__UserGroup.user_group_id == user_group_id
-    ).delete(synchronize_session=False)
-
-
-def _cleanup_llm_provider__user_group_relationships__no_commit(
-    db_session: Session, user_group_id: int
-) -> None:
-    """NOTE: does not commit the transaction."""
-    db_session.query(LLMProvider__UserGroup).filter(
-        LLMProvider__UserGroup.user_group_id == user_group_id
-    ).delete(synchronize_session=False)
-
-
 def _mark_user_group__cc_pair_relationships_outdated__no_commit(
     db_session: Session, user_group_id: int
 ) -> None:
@@ -476,21 +528,6 @@ def update_user_group(
     return db_user_group
 
 
-def _cleanup_token_rate_limit__user_group_relationships__no_commit(
-    db_session: Session, user_group_id: int
-) -> None:
-    """NOTE: does not commit the transaction."""
-    token_rate_limit__user_group_relationships = db_session.scalars(
-        select(TokenRateLimit__UserGroup).where(
-            TokenRateLimit__UserGroup.user_group_id == user_group_id
-        )
-    ).all()
-    for (
-        token_rate_limit__user_group_relationship
-    ) in token_rate_limit__user_group_relationships:
-        db_session.delete(token_rate_limit__user_group_relationship)
-
-
 def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
     stmt = select(UserGroup).where(UserGroup.id == user_group_id)
     db_user_group = db_session.scalar(stmt)
@@ -499,16 +536,31 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
 
     _check_user_group_is_modifiable(db_user_group)
 
+    _mark_user_group__cc_pair_relationships_outdated__no_commit(
+        db_session=db_session, user_group_id=user_group_id
+    )
+
     _cleanup_credential__user_group_relationships__no_commit(
         db_session=db_session, user_group_id=user_group_id
     )
     _cleanup_user__user_group_relationships__no_commit(
         db_session=db_session, user_group_id=user_group_id
     )
-    _mark_user_group__cc_pair_relationships_outdated__no_commit(
+    _cleanup_token_rate_limit__user_group_relationships__no_commit(
         db_session=db_session, user_group_id=user_group_id
     )
-    _cleanup_token_rate_limit__user_group_relationships__no_commit(
+    _cleanup_document_set__user_group_relationships__no_commit(
         db_session=db_session, user_group_id=user_group_id
     )
+    _cleanup_persona__user_group_relationships__no_commit(
+        db_session=db_session, user_group_id=user_group_id
+    )
+    _cleanup_user_group__cc_pair_relationships__no_commit(
+        db_session=db_session,
+        user_group_id=user_group_id,
+        outdated_only=False,
+    )
+    _cleanup_llm_provider__user_group_relationships__no_commit(
+        db_session=db_session, user_group_id=user_group_id
+    )
 
@@ -517,31 +569,12 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
     db_session.commit()
 
 
-def _cleanup_user_group__cc_pair_relationships__no_commit(
-    db_session: Session, user_group_id: int, outdated_only: bool
-) -> None:
-    """NOTE: does not commit the transaction."""
-    stmt = select(UserGroup__ConnectorCredentialPair).where(
-        UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
-    )
-    if outdated_only:
-        stmt = stmt.where(
-            UserGroup__ConnectorCredentialPair.is_current == False  # noqa: E712
-        )
-    user_group__cc_pair_relationships = db_session.scalars(stmt)
-    for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
-        db_session.delete(user_group__cc_pair_relationship)
-
-
-def _cleanup_document_set__user_group_relationships__no_commit(
-    db_session: Session, user_group_id: int
-) -> None:
-    """NOTE: does not commit the transaction."""
-    db_session.execute(
-        delete(DocumentSet__UserGroup).where(
-            DocumentSet__UserGroup.user_group_id == user_group_id
-        )
-    )
+def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
+    """
+    This assumes that all the fk cleanup has already been done.
+    """
+    db_session.delete(user_group)
+    db_session.commit()
 
 
 def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
@@ -553,29 +586,6 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
     db_session.commit()
 
 
-def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
-    _cleanup_llm_provider__user_group_relationships__no_commit(
-        db_session=db_session, user_group_id=user_group.id
-    )
-    _cleanup_user__user_group_relationships__no_commit(
-        db_session=db_session, user_group_id=user_group.id
-    )
-    _cleanup_user_group__cc_pair_relationships__no_commit(
-        db_session=db_session,
-        user_group_id=user_group.id,
-        outdated_only=False,
-    )
-    _cleanup_document_set__user_group_relationships__no_commit(
-        db_session=db_session, user_group_id=user_group.id
-    )
-
-    # need to flush so that we don't get a foreign key error when deleting the user group row
-    db_session.flush()
-
-    db_session.delete(user_group)
-    db_session.commit()
-
-
 def delete_user_group_cc_pair_relationship__no_commit(
     cc_pair_id: int, db_session: Session
 ) -> None:
```
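
To summarize the reshuffle in plain terms: the `_cleanup_*__no_commit` helpers were consolidated near the top of the module, `prepare_user_group_for_deletion` now clears every foreign-key relationship (users, credentials, token rate limits, document sets, personas, cc-pairs, LLM providers), and `delete_user_group` only drops the row. A hedged sketch of the resulting two-phase flow follows; the import paths and the wrapper function are assumptions for illustration, not part of the diff.

```python
from sqlalchemy.orm import Session

# Assumed import locations for the functions and model shown in the diff above.
from danswer.db.models import UserGroup
from ee.danswer.db.user_group import delete_user_group
from ee.danswer.db.user_group import prepare_user_group_for_deletion


def fully_remove_user_group(db_session: Session, user_group_id: int) -> None:
    # Phase 1: mark cc-pair relationships outdated and strip the group's
    # foreign-key rows; this helper commits at the end.
    prepare_user_group_for_deletion(db_session, user_group_id)

    # Phase 2 (typically after background sync work has finished): drop the
    # row itself; delete_user_group assumes all fk cleanup already happened.
    user_group = db_session.get(UserGroup, user_group_id)
    if user_group is not None:
        delete_user_group(db_session, user_group)
```
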
```diff
@@ -4,7 +4,7 @@ asyncpg==0.27.0
 atlassian-python-api==3.37.0
 beautifulsoup4==4.12.2
 boto3==1.34.84
-celery==5.3.4
+celery==5.5.0b4
 chardet==5.2.0
 dask==2023.8.1
 ddtrace==2.6.5
```
```diff
@@ -7,12 +7,13 @@ logfile=/var/log/supervisord.log
 # Cannot place this in Celery for now because Celery must run as a single process (see note below)
 # Indexing uses multi-processing to speed things up
 [program:document_indexing]
-environment=CURRENT_PROCESS_IS_AN_INDEXING_JOB=true,LOG_FILE_NAME=document_indexing
+environment=CURRENT_PROCESS_IS_AN_INDEXING_JOB=true
 command=python danswer/background/update.py
+stdout_logfile=/var/log/document_indexing.log
+stdout_logfile_maxbytes=16MB
 redirect_stderr=true
 autorestart=true
-
 
 # Background jobs that must be run async due to long time to completion
 # NOTE: due to an issue with Celery + SQLAlchemy
 # (https://github.com/celery/celery/issues/7007#issuecomment-1740139367)
@@ -31,7 +32,8 @@ command=celery -A danswer.background.celery.celery_run:celery_app worker
     --loglevel=INFO
     --logfile=/var/log/celery_worker_supervisor.log
     -Q celery,vespa_metadata_sync,connector_deletion
-environment=LOG_FILE_NAME=celery_worker
+stdout_logfile=/var/log/celery_worker.log
+stdout_logfile_maxbytes=16MB
 redirect_stderr=true
 autorestart=true
 
@@ -39,7 +41,8 @@ autorestart=true
 [program:celery_beat]
 command=celery -A danswer.background.celery.celery_run:celery_app beat
     --logfile=/var/log/celery_beat_supervisor.log
-environment=LOG_FILE_NAME=celery_beat
+stdout_logfile=/var/log/celery_beat.log
+stdout_logfile_maxbytes=16MB
 redirect_stderr=true
 
 # Listens for Slack messages and responds with answers
@@ -48,7 +51,8 @@ redirect_stderr=true
 # More details on setup here: https://docs.danswer.dev/slack_bot_setup
 [program:slack_bot]
 command=python danswer/danswerbot/slack/listener.py
-environment=LOG_FILE_NAME=slack_bot
+stdout_logfile=/var/log/slack_bot.log
+stdout_logfile_maxbytes=16MB
 redirect_stderr=true
 autorestart=true
 startretries=5
@@ -58,12 +62,10 @@ startsecs=60
 # No log rotation here, since it's stdout it's handled by the Docker container logging
 [program:log-redirect-handler]
 command=tail -qF
-    /var/log/document_indexing_info.log
-    /var/log/celery_beat_supervisor.log
-    /var/log/celery_worker_supervisor.log
-    /var/log/celery_beat_debug.log
-    /var/log/celery_worker_debug.log
-    /var/log/slack_bot_debug.log
+    /var/log/document_indexing.log
+    /var/log/celery_beat.log
+    /var/log/celery_worker.log
+    /var/log/slack_bot.log
 stdout_logfile=/dev/stdout
 stdout_logfile_maxbytes=0
 redirect_stderr=true
```
backend/tests/daily/connectors/jira/test_jira_basic.py (new file, 48 lines)

```python
import os
import time

import pytest

from danswer.configs.constants import DocumentSource
from danswer.connectors.danswer_jira.connector import JiraConnector


@pytest.fixture
def jira_connector() -> JiraConnector:
    connector = JiraConnector(
        "https://danswerai.atlassian.net/jira/software/c/projects/AS/boards/6",
        comment_email_blacklist=[],
    )
    connector.load_credentials(
        {
            "jira_user_email": os.environ["JIRA_USER_EMAIL"],
            "jira_api_token": os.environ["JIRA_API_TOKEN"],
        }
    )
    return connector


def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
    doc_batch_generator = jira_connector.poll_source(0, time.time())

    doc_batch = next(doc_batch_generator)
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 1

    doc = doc_batch[0]

    assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
    assert doc.semantic_identifier == "test123small"
    assert doc.source == DocumentSource.JIRA
    assert doc.metadata == {"priority": "Medium", "status": "Backlog"}
    assert doc.secondary_owners is None
    assert doc.title is None
    assert doc.from_ingestion_api is False
    assert doc.additional_info is None

    assert len(doc.sections) == 1
    section = doc.sections[0]
    assert section.text == "example_text\n"
    assert section.link == "https://danswerai.atlassian.net/browse/AS-2"
```
```diff
@@ -81,6 +81,6 @@ RUN pip install --no-cache-dir --upgrade \
         -r /tmp/dev-requirements.txt
 COPY ./tests/integration /app/tests/integration
 
-ENV PYTHONPATH /app
+ENV PYTHONPATH=/app
 
 CMD ["pytest", "-s", "/app/tests/integration"]
```
A new unit-test file (136 lines) covers fetch_jira_issues_batch and the ticket size limit:

```python
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from jira.resources import Issue
from pytest_mock import MockFixture

from danswer.connectors.danswer_jira.connector import fetch_jira_issues_batch


@pytest.fixture
def mock_jira_client() -> MagicMock:
    return MagicMock()


@pytest.fixture
def mock_issue_small() -> MagicMock:
    issue = MagicMock()
    issue.key = "SMALL-1"
    issue.fields.description = "Small description"
    issue.fields.comment.comments = [
        MagicMock(body="Small comment 1"),
        MagicMock(body="Small comment 2"),
    ]
    issue.fields.creator.displayName = "John Doe"
    issue.fields.creator.emailAddress = "john@example.com"
    issue.fields.summary = "Small Issue"
    issue.fields.updated = "2023-01-01T00:00:00+0000"
    issue.fields.labels = []
    return issue


@pytest.fixture
def mock_issue_large() -> MagicMock:
    # This will be larger than 100KB
    issue = MagicMock()
    issue.key = "LARGE-1"
    issue.fields.description = "a" * 99_000
    issue.fields.comment.comments = [
        MagicMock(body="Large comment " * 1000),
        MagicMock(body="Another large comment " * 1000),
    ]
    issue.fields.creator.displayName = "Jane Doe"
    issue.fields.creator.emailAddress = "jane@example.com"
    issue.fields.summary = "Large Issue"
    issue.fields.updated = "2023-01-02T00:00:00+0000"
    issue.fields.labels = []
    return issue


@pytest.fixture
def patched_type() -> Callable[[Any], type]:
    def _patched_type(obj: Any) -> type:
        if isinstance(obj, MagicMock):
            return Issue
        return type(obj)

    return _patched_type


@pytest.fixture
def mock_jira_api_version() -> Generator[Any, Any, Any]:
    with patch("danswer.connectors.danswer_jira.connector.JIRA_API_VERSION", "2"):
        yield


@pytest.fixture
def patched_environment(
    patched_type: type,
    mock_jira_api_version: MockFixture,
) -> Generator[Any, Any, Any]:
    with patch("danswer.connectors.danswer_jira.connector.type", patched_type):
        yield


def test_fetch_jira_issues_batch_small_ticket(
    mock_jira_client: MagicMock,
    mock_issue_small: MagicMock,
    patched_environment: MockFixture,
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small]

    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 1
    assert len(docs) == 1
    assert docs[0].id.endswith("/SMALL-1")
    assert "Small description" in docs[0].sections[0].text
    assert "Small comment 1" in docs[0].sections[0].text
    assert "Small comment 2" in docs[0].sections[0].text


def test_fetch_jira_issues_batch_large_ticket(
    mock_jira_client: MagicMock,
    mock_issue_large: MagicMock,
    patched_environment: MockFixture,
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_large]

    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 1
    assert len(docs) == 0  # The large ticket should be skipped


def test_fetch_jira_issues_batch_mixed_tickets(
    mock_jira_client: MagicMock,
    mock_issue_small: MagicMock,
    mock_issue_large: MagicMock,
    patched_environment: MockFixture,
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]

    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 2
    assert len(docs) == 1  # Only the small ticket should be included
    assert docs[0].id.endswith("/SMALL-1")


@patch("danswer.connectors.danswer_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50)
def test_fetch_jira_issues_batch_custom_size_limit(
    mock_jira_client: MagicMock,
    mock_issue_small: MagicMock,
    mock_issue_large: MagicMock,
    patched_environment: MockFixture,
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]

    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 2
    assert len(docs) == 0  # Both tickets should be skipped due to the low size limit
```
backend/tests/unit/danswer/indexing/conftest.py (new file, 18 lines)

```python
from typing import Any

import pytest

from danswer.indexing.indexing_heartbeat import Heartbeat


class MockHeartbeat(Heartbeat):
    def __init__(self) -> None:
        self.call_count = 0

    def heartbeat(self, metadata: Any = None) -> None:
        self.call_count += 1


@pytest.fixture
def mock_heartbeat() -> MockHeartbeat:
    return MockHeartbeat()
```
```diff
@@ -1,11 +1,24 @@
+import pytest
+
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import DefaultIndexingEmbedder
+from tests.unit.danswer.indexing.conftest import MockHeartbeat
 
 
-def test_chunk_document() -> None:
+@pytest.fixture
+def embedder() -> DefaultIndexingEmbedder:
+    return DefaultIndexingEmbedder(
+        model_name="intfloat/e5-base-v2",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+    )
+
+
+def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
     short_section_1 = "This is a short section."
     long_section = (
         "This is a long section that should be split into multiple chunks. " * 100
@@ -30,18 +43,11 @@ def test_chunk_document() -> None:
         ],
     )
 
-    embedder = DefaultIndexingEmbedder(
-        model_name="intfloat/e5-base-v2",
-        normalize=True,
-        query_prefix=None,
-        passage_prefix=None,
-    )
-
     chunker = Chunker(
         tokenizer=embedder.embedding_model.tokenizer,
         enable_multipass=False,
     )
-    chunks = chunker.chunk(document)
+    chunks = chunker.chunk([document])
 
     assert len(chunks) == 5
     assert short_section_1 in chunks[0].content
@@ -49,3 +55,29 @@ def test_chunk_document() -> None:
     assert short_section_4 in chunks[-1].content
     assert "tag1" in chunks[0].metadata_suffix_keyword
     assert "tag2" in chunks[0].metadata_suffix_semantic
+
+
+def test_chunker_heartbeat(
+    embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
+) -> None:
+    document = Document(
+        id="test_doc",
+        source=DocumentSource.WEB,
+        semantic_identifier="Test Document",
+        metadata={"tags": ["tag1", "tag2"]},
+        doc_updated_at=None,
+        sections=[
+            Section(text="This is a short section.", link="link1"),
+        ],
+    )
+
+    chunker = Chunker(
+        tokenizer=embedder.embedding_model.tokenizer,
+        enable_multipass=False,
+        heartbeat=mock_heartbeat,
+    )
+
+    chunks = chunker.chunk([document])
+
+    assert mock_heartbeat.call_count == 1
+    assert len(chunks) > 0
```
backend/tests/unit/danswer/indexing/test_embedder.py (new file, 90 lines)

```python
from collections.abc import Generator
from unittest.mock import Mock
from unittest.mock import patch

import pytest

from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType


@pytest.fixture
def mock_embedding_model() -> Generator[Mock, None, None]:
    with patch("danswer.indexing.embedder.EmbeddingModel") as mock:
        yield mock


def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
    # Setup
    embedder = DefaultIndexingEmbedder(
        model_name="test-model",
        normalize=True,
        query_prefix=None,
        passage_prefix=None,
        provider_type=EmbeddingProvider.OPENAI,
    )

    # Mock the encode method of the embedding model
    mock_embedding_model.return_value.encode.side_effect = [
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # Main chunk embeddings
        [[7.0, 8.0, 9.0]],  # Title embedding
    ]

    # Create test input
    source_doc = Document(
        id="test_doc",
        source=DocumentSource.WEB,
        semantic_identifier="Test Document",
        metadata={"tags": ["tag1", "tag2"]},
        doc_updated_at=None,
        sections=[
            Section(text="This is a short section.", link="link1"),
        ],
    )
    chunks: list[DocAwareChunk] = [
        DocAwareChunk(
            chunk_id=1,
            blurb="This is a short section.",
            content="Test chunk",
            source_links={0: "link1"},
            section_continuation=False,
            source_document=source_doc,
            title_prefix="Title: ",
            metadata_suffix_semantic="",
            metadata_suffix_keyword="",
            mini_chunk_texts=None,
            large_chunk_reference_ids=[],
        )
    ]

    # Execute
    result: list[IndexChunk] = embedder.embed_chunks(chunks)

    # Assert
    assert len(result) == 1
    assert isinstance(result[0], IndexChunk)
    assert result[0].content == "Test chunk"
    assert result[0].embeddings == ChunkEmbedding(
        full_embedding=[1.0, 2.0, 3.0],
        mini_chunk_embeddings=[],
    )
    assert result[0].title_embedding == [7.0, 8.0, 9.0]

    # Verify the embedding model was called correctly
    mock_embedding_model.return_value.encode.assert_any_call(
        texts=["Title: Test chunk"],
        text_type=EmbedTextType.PASSAGE,
        large_chunks_present=False,
    )
    # title only embedding call
    mock_embedding_model.return_value.encode.assert_any_call(
        ["Test Document"],
        text_type=EmbedTextType.PASSAGE,
    )
```
backend/tests/unit/danswer/indexing/test_heartbeat.py (new file, 80 lines)

```python
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from sqlalchemy.orm import Session

from danswer.db.index_attempt import IndexAttempt
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat


@pytest.fixture
def mock_db_session() -> MagicMock:
    return MagicMock(spec=Session)


@pytest.fixture
def mock_index_attempt() -> MagicMock:
    return MagicMock(spec=IndexAttempt)


def test_indexing_heartbeat(
    mock_db_session: MagicMock, mock_index_attempt: MagicMock
) -> None:
    with patch(
        "danswer.indexing.indexing_heartbeat.get_index_attempt"
    ) as mock_get_index_attempt:
        mock_get_index_attempt.return_value = mock_index_attempt

        heartbeat = IndexingHeartbeat(
            index_attempt_id=1, db_session=mock_db_session, freq=5
        )

        # Test that heartbeat doesn't update before freq is reached
        for _ in range(4):
            heartbeat.heartbeat()

        mock_db_session.commit.assert_not_called()

        # Test that heartbeat updates when freq is reached
        heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once_with(
            db_session=mock_db_session, index_attempt_id=1
        )
        assert mock_index_attempt.time_updated is not None
        mock_db_session.commit.assert_called_once()

        # Reset mock calls
        mock_db_session.reset_mock()
        mock_get_index_attempt.reset_mock()

        # Test that heartbeat updates again after freq more calls
        for _ in range(5):
            heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once()
        mock_db_session.commit.assert_called_once()


def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
    with patch(
        "danswer.indexing.indexing_heartbeat.get_index_attempt"
    ) as mock_get_index_attempt, patch(
        "danswer.indexing.indexing_heartbeat.logger"
    ) as mock_logger:
        mock_get_index_attempt.return_value = None

        heartbeat = IndexingHeartbeat(
            index_attempt_id=1, db_session=mock_db_session, freq=1
        )

        heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once_with(
            db_session=mock_db_session, index_attempt_id=1
        )
        mock_logger.error.assert_called_once_with(
            "Index attempt not found, this should not happen!"
        )
        mock_db_session.commit.assert_not_called()
```
```diff
@@ -25,10 +25,10 @@ COPY . .
 RUN npm ci
 
 # needed to get the `standalone` dir we expect later
-ENV NEXT_PRIVATE_STANDALONE true
+ENV NEXT_PRIVATE_STANDALONE=true
 
 # Disable automatic telemetry collection
-ENV NEXT_TELEMETRY_DISABLED 1
+ENV NEXT_TELEMETRY_DISABLED=1
 
 # Environment variables must be present at build time
 # https://github.com/vercel/next.js/discussions/14030
@@ -77,7 +77,7 @@ RUN rm -rf /usr/local/lib/node_modules
 # ENV NODE_ENV production
 
 # Disable automatic telemetry collection
-ENV NEXT_TELEMETRY_DISABLED 1
+ENV NEXT_TELEMETRY_DISABLED=1
 
 # Don't run production as root
 RUN addgroup --system --gid 1001 nodejs
```
```diff
@@ -202,7 +202,7 @@ export const PromptLibraryTable = ({
           ))}
         </div>
       </div>
-      <div className="mx-auto">
+      <div className="mx-auto overflow-x-auto">
         <Table>
           <TableHead>
             <TableRow>
@@ -220,7 +220,16 @@ export const PromptLibraryTable = ({
               .map((item) => (
                 <TableRow key={item.id}>
                   <TableCell>{item.prompt}</TableCell>
-                  <TableCell>{item.content}</TableCell>
+                  <TableCell
+                    className="
+                      max-w-xs
+                      overflow-hidden
+                      text-ellipsis
+                      break-words
+                    "
+                  >
+                    {item.content}
+                  </TableCell>
                   <TableCell>{item.active ? "Active" : "Inactive"}</TableCell>
                   <TableCell>
                     <button
```
```diff
@@ -162,6 +162,9 @@ export function ChatPage({
     user,
     availableAssistants
   );
+  const finalAssistants = user
+    ? orderAssistantsForUser(visibleAssistants, user)
+    : visibleAssistants;
 
   const existingChatSessionAssistantId = selectedChatSession?.persona_id;
   const [selectedAssistant, setSelectedAssistant] = useState<
@@ -216,7 +219,7 @@ export function ChatPage({
   const liveAssistant =
     alternativeAssistant ||
     selectedAssistant ||
-    visibleAssistants[0] ||
+    finalAssistants[0] ||
     availableAssistants[0];
 
   useEffect(() => {
@@ -686,7 +689,7 @@ export function ChatPage({
   useEffect(() => {
     if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
       setSelectedAssistant(
-        visibleAssistants.find((persona) => persona.id === defaultAssistantId)
+        finalAssistants.find((persona) => persona.id === defaultAssistantId)
       );
     }
   }, [defaultAssistantId]);
@@ -2390,10 +2393,7 @@ export function ChatPage({
           showDocs={() => setDocumentSelection(true)}
           selectedDocuments={selectedDocuments}
           // assistant stuff
-          assistantOptions={orderAssistantsForUser(
-            visibleAssistants,
-            user
-          )}
+          assistantOptions={finalAssistants}
           selectedAssistant={liveAssistant}
           setSelectedAssistant={onAssistantChange}
           setAlternativeAssistant={setAlternativeAssistant}
```
```diff
@@ -188,14 +188,6 @@ export async function fetchChatData(searchParams: {
     !hasAnyConnectors &&
     (!user || user.role === "admin");
 
-  const shouldDisplaySourcesIncompleteModal =
-    hasAnyConnectors &&
-    !shouldShowWelcomeModal &&
-    !ccPairs.some(
-      (ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
-    ) &&
-    (!user || user.role == "admin");
-
   // if no connectors are setup, only show personas that are pure
   // passthrough and don't do any retrieval
   if (!hasAnyConnectors) {
```