Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)
Compare commits
15 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | be4aef8696 |  |
|  | c28a8d831b |  |
|  | 3cafedcf22 |  |
|  | f548164464 |  |
|  | 2fd557c7ea |  |
|  | 3a8f06c765 |  |
|  | 7f583420e2 |  |
|  | c3f96c6be6 |  |
|  | 8a9e02c25a |  |
|  | fdd9ae347f |  |
|  | 85c1efcb25 |  |
|  | f4d7e34fa3 |  |
|  | 0817f91e6a |  |
|  | af20c13d8b |  |
|  | fc45354aea |  |
.github/workflows/pr-python-checks.yml (vendored, 4 changed lines)

```diff
@@ -3,7 +3,9 @@ name: Python Checks
 on:
   merge_group:
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release/**'

 jobs:
   mypy-check:
```
.github/workflows/pr-python-tests.yml (vendored, 4 changed lines)

```diff
@@ -3,7 +3,9 @@ name: Python Unit Tests
 on:
   merge_group:
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release/**'

 jobs:
   backend-check:
```
.github/workflows/run-it.yml (vendored, 4 changed lines)

```diff
@@ -6,7 +6,9 @@ concurrency:
 on:
   merge_group:
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release/**'

 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
```
````diff
@@ -28,4 +28,6 @@ MacOS will likely require you to remove some quarantine attributes on some of th
 After installing pre-commit, run the following command:
+```bash
 sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
+```

 doc version 0.1
````
```diff
@@ -29,6 +29,7 @@ from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.embedder import DefaultIndexingEmbedder
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
 from danswer.indexing.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logger import IndexAttemptSingleton
 from danswer.utils.logger import setup_logger
@@ -103,15 +104,24 @@ def _run_indexing(
     )

     embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
-        search_settings=search_settings
+        search_settings=search_settings,
+        heartbeat=IndexingHeartbeat(
+            index_attempt_id=index_attempt.id,
+            db_session=db_session,
+            # let the world know we're still making progress after
+            # every 10 batches
+            freq=10,
+        ),
     )

     indexing_pipeline = build_indexing_pipeline(
         attempt_id=index_attempt.id,
         embedder=embedding_model,
         document_index=document_index,
-        ignore_time_skip=index_attempt.from_beginning
-        or (search_settings.status == IndexModelStatus.FUTURE),
+        ignore_time_skip=(
+            index_attempt.from_beginning
+            or (search_settings.status == IndexModelStatus.FUTURE)
+        ),
         db_session=db_session,
     )

```
```diff
@@ -167,7 +167,7 @@ REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))  # br
 # https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
 # should be one of "required", "optional", or "none"
 REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
-REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
+REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)

 CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400))  # seconds

@@ -247,6 +247,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
     for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
     if ignored_tag
 ]
+# Maximum size for Jira tickets in bytes (default: 100KB)
+JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
+    os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
+)

 GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")

```
```diff
@@ -9,6 +9,7 @@ from jira.resources import Issue

 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
+from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
 from danswer.connectors.interfaces import GenerateDocumentsOutput

@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
             else extract_text_from_adf(jira.raw["fields"]["description"])
         )
         comments = _get_comment_strs(jira, comment_email_blacklist)
-        semantic_rep = f"{description}\n" + "\n".join(
+        ticket_content = f"{description}\n" + "\n".join(
             [f"Comment: {comment}" for comment in comments if comment]
         )

+        # Check ticket size
+        if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
+            logger.info(
+                f"Skipping {jira.key} because it exceeds the maximum size of "
+                f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
+            )
+            continue
+
         page_url = f"{jira_client.client_info()}/browse/{jira.key}"

         people = set()

@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
         doc_batch.append(
             Document(
                 id=page_url,
-                sections=[Section(link=page_url, text=semantic_rep)],
+                sections=[Section(link=page_url, text=ticket_content)],
                 source=DocumentSource.JIRA,
                 semantic_identifier=jira.fields.summary,
                 doc_updated_at=time_str_to_utc(jira.fields.updated),
```
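Taken together, the config and connector hunks above add a configurable size guard: the connector measures the UTF-8 byte length of the assembled ticket text and skips anything over `JIRA_CONNECTOR_MAX_TICKET_SIZE`. A standalone sketch of that check follows; the helper name and the sample strings are illustrative, not from the diff.

```python
import os

# Mirrors the new config value above: 100KB default, overridable via env var.
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
    os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)


def ticket_too_large(ticket_content: str) -> bool:
    # The limit applies to the UTF-8 encoding, so multi-byte characters count
    # for more than one byte, just like in the connector code.
    return len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE


print(ticket_too_large("short ticket"))  # False
print(ticket_too_large("a" * 200_000))   # True with the default 100KB limit
```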
```diff
@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
     get_metadata_keys_to_ignore,
 )
 from danswer.connectors.models import Document
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.indexing.models import DocAwareChunk
 from danswer.natural_language_processing.utils import BaseTokenizer
 from danswer.utils.logger import setup_logger

@@ -123,6 +124,7 @@ class Chunker:
         chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
         chunk_overlap: int = CHUNK_OVERLAP,
         mini_chunk_size: int = MINI_CHUNK_SIZE,
+        heartbeat: Heartbeat | None = None,
     ) -> None:
         from llama_index.text_splitter import SentenceSplitter

@@ -131,6 +133,7 @@ class Chunker:
         self.enable_multipass = enable_multipass
         self.enable_large_chunks = enable_large_chunks
         self.tokenizer = tokenizer
+        self.heartbeat = heartbeat

         self.blurb_splitter = SentenceSplitter(
             tokenizer=tokenizer.tokenize,

@@ -255,7 +258,7 @@ class Chunker:
         # If the chunk does not have any useable content, it will not be indexed
         return chunks

-    def chunk(self, document: Document) -> list[DocAwareChunk]:
+    def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
         # Specifically for reproducing an issue with gmail
         if document.source == DocumentSource.GMAIL:
             logger.debug(f"Chunking {document.semantic_identifier}")

@@ -302,3 +305,13 @@ class Chunker:
             normal_chunks.extend(large_chunks)

         return normal_chunks
+
+    def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
+        final_chunks: list[DocAwareChunk] = []
+        for document in documents:
+            final_chunks.extend(self._handle_single_document(document))
+
+            if self.heartbeat:
+                self.heartbeat.heartbeat()
+
+        return final_chunks
```
```diff
@@ -1,12 +1,8 @@
 from abc import ABC
 from abc import abstractmethod

-from sqlalchemy.orm import Session
-
-from danswer.db.models import IndexModelStatus
 from danswer.db.models import SearchSettings
-from danswer.db.search_settings import get_current_search_settings
-from danswer.db.search_settings import get_secondary_search_settings
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk

@@ -24,6 +20,9 @@ logger = setup_logger()


 class IndexingEmbedder(ABC):
+    """Converts chunks into chunks with embeddings. Note that one chunk may have
+    multiple embeddings associated with it."""
+
     def __init__(
         self,
         model_name: str,

@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
         provider_type: EmbeddingProvider | None,
         api_key: str | None,
         api_url: str | None,
+        heartbeat: Heartbeat | None,
     ):
         self.model_name = model_name
         self.normalize = normalize

@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
             server_host=INDEXING_MODEL_SERVER_HOST,
             server_port=INDEXING_MODEL_SERVER_PORT,
             retrim_content=True,
+            heartbeat=heartbeat,
         )

     @abstractmethod

@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         provider_type: EmbeddingProvider | None = None,
         api_key: str | None = None,
         api_url: str | None = None,
+        heartbeat: Heartbeat | None = None,
     ):
         super().__init__(
             model_name,

@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             provider_type,
             api_key,
             api_url,
+            heartbeat,
         )

     @log_function_time()

@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
                 title_embed_dict[title] = title_embedding

             new_embedded_chunk = IndexChunk(
-                **chunk.dict(),
+                **chunk.model_dump(),
                 embeddings=ChunkEmbedding(
                     full_embedding=chunk_embeddings[0],
                     mini_chunk_embeddings=chunk_embeddings[1:],

@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):

     @classmethod
     def from_db_search_settings(
-        cls, search_settings: SearchSettings
+        cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
     ) -> "DefaultIndexingEmbedder":
         return cls(
             model_name=search_settings.model_name,

@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             provider_type=search_settings.provider_type,
             api_key=search_settings.api_key,
             api_url=search_settings.api_url,
+            heartbeat=heartbeat,
         )
-
-
-def get_embedding_model_from_search_settings(
-    db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
-) -> IndexingEmbedder:
-    search_settings: SearchSettings | None
-    if index_model_status == IndexModelStatus.PRESENT:
-        search_settings = get_current_search_settings(db_session)
-    elif index_model_status == IndexModelStatus.FUTURE:
-        search_settings = get_secondary_search_settings(db_session)
-        if not search_settings:
-            raise RuntimeError("No secondary index configured")
-    else:
-        raise RuntimeError("Not supporting embedding model rollbacks")
-
-    return DefaultIndexingEmbedder(
-        model_name=search_settings.model_name,
-        normalize=search_settings.normalize,
-        query_prefix=search_settings.query_prefix,
-        passage_prefix=search_settings.passage_prefix,
-        provider_type=search_settings.provider_type,
-        api_key=search_settings.api_key,
-        api_url=search_settings.api_url,
-    )
```
backend/danswer/indexing/indexing_heartbeat.py (new file, 41 lines)

```python
import abc
from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import Session

from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger

logger = setup_logger()


class Heartbeat(abc.ABC):
    """Useful for any long-running work that goes through a bunch of items
    and needs to occasionally give updates on progress.
    e.g. chunking, embedding, updating vespa, etc."""

    @abc.abstractmethod
    def heartbeat(self, metadata: Any = None) -> None:
        raise NotImplementedError


class IndexingHeartbeat(Heartbeat):
    def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
        self.cnt = 0

        self.index_attempt_id = index_attempt_id
        self.db_session = db_session
        self.freq = freq

    def heartbeat(self, metadata: Any = None) -> None:
        self.cnt += 1
        if self.cnt % self.freq == 0:
            index_attempt = get_index_attempt(
                db_session=self.db_session, index_attempt_id=self.index_attempt_id
            )
            if index_attempt:
                index_attempt.time_updated = func.now()
                self.db_session.commit()
            else:
                logger.error("Index attempt not found, this should not happen!")
```
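Since `Heartbeat` only requires a `heartbeat()` method, other progress reporters can be plugged into the same slots. A minimal sketch of a hypothetical alternative (not part of this change) that logs progress instead of touching the database:

```python
from typing import Any

from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.utils.logger import setup_logger

logger = setup_logger()


class LoggingHeartbeat(Heartbeat):
    """Hypothetical Heartbeat that logs progress every `freq` items."""

    def __init__(self, freq: int = 10) -> None:
        self.cnt = 0
        self.freq = freq

    def heartbeat(self, metadata: Any = None) -> None:
        self.cnt += 1
        if self.cnt % self.freq == 0:
            logger.info(f"Still making progress: {self.cnt} items processed")
```

Such an implementation could be passed anywhere the diffs accept a `Heartbeat`, for example `Chunker(..., heartbeat=LoggingHeartbeat())`.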
```diff
@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
 from danswer.document_index.interfaces import DocumentMetadata
 from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import IndexingEmbedder
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.utils.logger import setup_logger

@@ -283,18 +284,10 @@ def index_doc_batch(
         return 0, 0

     logger.debug("Starting chunking")
-    chunks: list[DocAwareChunk] = []
-    for document in ctx.updatable_docs:
-        chunks.extend(chunker.chunk(document=document))
+    chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)

     logger.debug("Starting embedding")
-    chunks_with_embeddings = (
-        embedder.embed_chunks(
-            chunks=chunks,
-        )
-        if chunks
-        else []
-    )
+    chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []

     updatable_ids = [doc.id for doc in ctx.updatable_docs]

@@ -406,6 +399,13 @@ def build_indexing_pipeline(
         tokenizer=embedder.embedding_model.tokenizer,
         enable_multipass=multipass,
         enable_large_chunks=enable_large_chunks,
+        # after every doc, update status in case there are a bunch of
+        # really long docs
+        heartbeat=IndexingHeartbeat(
+            index_attempt_id=attempt_id, db_session=db_session, freq=1
+        )
+        if attempt_id
+        else None,
     )

     return partial(
```
```diff
@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
 )
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger

@@ -95,6 +96,7 @@ class EmbeddingModel:
         api_url: str | None,
         provider_type: EmbeddingProvider | None,
         retrim_content: bool = False,
+        heartbeat: Heartbeat | None = None,
     ) -> None:
         self.api_key = api_key
         self.provider_type = provider_type

@@ -107,6 +109,7 @@ class EmbeddingModel:
         self.tokenizer = get_tokenizer(
             model_name=model_name, provider_type=provider_type
         )
+        self.heartbeat = heartbeat

         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

@@ -166,6 +169,9 @@ class EmbeddingModel:

             response = self._make_model_server_request(embed_request)
             embeddings.extend(response.embeddings)
+
+            if self.heartbeat:
+                self.heartbeat.heartbeat()
         return embeddings

     def encode(
```
```diff
@@ -42,7 +42,7 @@ class RedisPool:
         db: int = REDIS_DB_NUMBER,
         password: str = REDIS_PASSWORD,
         max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
-        ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
+        ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
         ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
         ssl: bool = False,
     ) -> redis.ConnectionPool:
```
@@ -0,0 +1,136 @@ (new file)

```python
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from jira.resources import Issue
from pytest_mock import MockFixture

from danswer.connectors.danswer_jira.connector import fetch_jira_issues_batch


@pytest.fixture
def mock_jira_client() -> MagicMock:
    return MagicMock()


@pytest.fixture
def mock_issue_small() -> MagicMock:
    issue = MagicMock()
    issue.key = "SMALL-1"
    issue.fields.description = "Small description"
    issue.fields.comment.comments = [
        MagicMock(body="Small comment 1"),
        MagicMock(body="Small comment 2"),
    ]
    issue.fields.creator.displayName = "John Doe"
    issue.fields.creator.emailAddress = "john@example.com"
    issue.fields.summary = "Small Issue"
    issue.fields.updated = "2023-01-01T00:00:00+0000"
    issue.fields.labels = []
    return issue


@pytest.fixture
def mock_issue_large() -> MagicMock:
    # This will be larger than 100KB
    issue = MagicMock()
    issue.key = "LARGE-1"
    issue.fields.description = "a" * 99_000
    issue.fields.comment.comments = [
        MagicMock(body="Large comment " * 1000),
        MagicMock(body="Another large comment " * 1000),
    ]
    issue.fields.creator.displayName = "Jane Doe"
    issue.fields.creator.emailAddress = "jane@example.com"
    issue.fields.summary = "Large Issue"
    issue.fields.updated = "2023-01-02T00:00:00+0000"
    issue.fields.labels = []
    return issue


@pytest.fixture
def patched_type() -> Callable[[Any], type]:
    def _patched_type(obj: Any) -> type:
        if isinstance(obj, MagicMock):
            return Issue
        return type(obj)

    return _patched_type


@pytest.fixture
def mock_jira_api_version() -> Generator[Any, Any, Any]:
    with patch("danswer.connectors.danswer_jira.connector.JIRA_API_VERSION", "2"):
        yield


@pytest.fixture
def patched_environment(
    patched_type: type,
    mock_jira_api_version: MockFixture,
) -> Generator[Any, Any, Any]:
    with patch("danswer.connectors.danswer_jira.connector.type", patched_type):
        yield


def test_fetch_jira_issues_batch_small_ticket(
    mock_jira_client: MagicMock,
    mock_issue_small: MagicMock,
    patched_environment: MockFixture,
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small]

    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 1
    assert len(docs) == 1
    assert docs[0].id.endswith("/SMALL-1")
    assert "Small description" in docs[0].sections[0].text
    assert "Small comment 1" in docs[0].sections[0].text
    assert "Small comment 2" in docs[0].sections[0].text


def test_fetch_jira_issues_batch_large_ticket(
    mock_jira_client: MagicMock,
    mock_issue_large: MagicMock,
    patched_environment: MockFixture,
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_large]

    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 1
    assert len(docs) == 0  # The large ticket should be skipped


def test_fetch_jira_issues_batch_mixed_tickets(
    mock_jira_client: MagicMock,
    mock_issue_small: MagicMock,
    mock_issue_large: MagicMock,
    patched_environment: MockFixture,
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]

    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 2
    assert len(docs) == 1  # Only the small ticket should be included
    assert docs[0].id.endswith("/SMALL-1")


@patch("danswer.connectors.danswer_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50)
def test_fetch_jira_issues_batch_custom_size_limit(
    mock_jira_client: MagicMock,
    mock_issue_small: MagicMock,
    mock_issue_large: MagicMock,
    patched_environment: MockFixture,
) -> None:
    mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]

    docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)

    assert count == 2
    assert len(docs) == 0  # Both tickets should be skipped due to the low size limit
```
backend/tests/unit/danswer/indexing/conftest.py (new file, 18 lines)

```python
from typing import Any

import pytest

from danswer.indexing.indexing_heartbeat import Heartbeat


class MockHeartbeat(Heartbeat):
    def __init__(self) -> None:
        self.call_count = 0

    def heartbeat(self, metadata: Any = None) -> None:
        self.call_count += 1


@pytest.fixture
def mock_heartbeat() -> MockHeartbeat:
    return MockHeartbeat()
```
```diff
@@ -1,11 +1,24 @@
+import pytest
+
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import DefaultIndexingEmbedder
+from tests.unit.danswer.indexing.conftest import MockHeartbeat


-def test_chunk_document() -> None:
+@pytest.fixture
+def embedder() -> DefaultIndexingEmbedder:
+    return DefaultIndexingEmbedder(
+        model_name="intfloat/e5-base-v2",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+    )
+
+
+def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
     short_section_1 = "This is a short section."
     long_section = (
         "This is a long section that should be split into multiple chunks. " * 100

@@ -30,18 +43,11 @@ def test_chunk_document() -> None:
         ],
     )

-    embedder = DefaultIndexingEmbedder(
-        model_name="intfloat/e5-base-v2",
-        normalize=True,
-        query_prefix=None,
-        passage_prefix=None,
-    )
-
     chunker = Chunker(
         tokenizer=embedder.embedding_model.tokenizer,
         enable_multipass=False,
     )
-    chunks = chunker.chunk(document)
+    chunks = chunker.chunk([document])

     assert len(chunks) == 5
     assert short_section_1 in chunks[0].content

@@ -49,3 +55,29 @@ def test_chunk_document() -> None:
     assert short_section_4 in chunks[-1].content
     assert "tag1" in chunks[0].metadata_suffix_keyword
     assert "tag2" in chunks[0].metadata_suffix_semantic
+
+
+def test_chunker_heartbeat(
+    embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
+) -> None:
+    document = Document(
+        id="test_doc",
+        source=DocumentSource.WEB,
+        semantic_identifier="Test Document",
+        metadata={"tags": ["tag1", "tag2"]},
+        doc_updated_at=None,
+        sections=[
+            Section(text="This is a short section.", link="link1"),
+        ],
+    )
+
+    chunker = Chunker(
+        tokenizer=embedder.embedding_model.tokenizer,
+        enable_multipass=False,
+        heartbeat=mock_heartbeat,
+    )
+
+    chunks = chunker.chunk([document])
+
+    assert mock_heartbeat.call_count == 1
+    assert len(chunks) > 0
```
backend/tests/unit/danswer/indexing/test_embedder.py (new file, 90 lines)

```python
from collections.abc import Generator
from unittest.mock import Mock
from unittest.mock import patch

import pytest

from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType


@pytest.fixture
def mock_embedding_model() -> Generator[Mock, None, None]:
    with patch("danswer.indexing.embedder.EmbeddingModel") as mock:
        yield mock


def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
    # Setup
    embedder = DefaultIndexingEmbedder(
        model_name="test-model",
        normalize=True,
        query_prefix=None,
        passage_prefix=None,
        provider_type=EmbeddingProvider.OPENAI,
    )

    # Mock the encode method of the embedding model
    mock_embedding_model.return_value.encode.side_effect = [
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # Main chunk embeddings
        [[7.0, 8.0, 9.0]],  # Title embedding
    ]

    # Create test input
    source_doc = Document(
        id="test_doc",
        source=DocumentSource.WEB,
        semantic_identifier="Test Document",
        metadata={"tags": ["tag1", "tag2"]},
        doc_updated_at=None,
        sections=[
            Section(text="This is a short section.", link="link1"),
        ],
    )
    chunks: list[DocAwareChunk] = [
        DocAwareChunk(
            chunk_id=1,
            blurb="This is a short section.",
            content="Test chunk",
            source_links={0: "link1"},
            section_continuation=False,
            source_document=source_doc,
            title_prefix="Title: ",
            metadata_suffix_semantic="",
            metadata_suffix_keyword="",
            mini_chunk_texts=None,
            large_chunk_reference_ids=[],
        )
    ]

    # Execute
    result: list[IndexChunk] = embedder.embed_chunks(chunks)

    # Assert
    assert len(result) == 1
    assert isinstance(result[0], IndexChunk)
    assert result[0].content == "Test chunk"
    assert result[0].embeddings == ChunkEmbedding(
        full_embedding=[1.0, 2.0, 3.0],
        mini_chunk_embeddings=[],
    )
    assert result[0].title_embedding == [7.0, 8.0, 9.0]

    # Verify the embedding model was called correctly
    mock_embedding_model.return_value.encode.assert_any_call(
        texts=["Title: Test chunk"],
        text_type=EmbedTextType.PASSAGE,
        large_chunks_present=False,
    )
    # title only embedding call
    mock_embedding_model.return_value.encode.assert_any_call(
        ["Test Document"],
        text_type=EmbedTextType.PASSAGE,
    )
```
backend/tests/unit/danswer/indexing/test_heartbeat.py (new file, 80 lines)

```python
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from sqlalchemy.orm import Session

from danswer.db.index_attempt import IndexAttempt
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat


@pytest.fixture
def mock_db_session() -> MagicMock:
    return MagicMock(spec=Session)


@pytest.fixture
def mock_index_attempt() -> MagicMock:
    return MagicMock(spec=IndexAttempt)


def test_indexing_heartbeat(
    mock_db_session: MagicMock, mock_index_attempt: MagicMock
) -> None:
    with patch(
        "danswer.indexing.indexing_heartbeat.get_index_attempt"
    ) as mock_get_index_attempt:
        mock_get_index_attempt.return_value = mock_index_attempt

        heartbeat = IndexingHeartbeat(
            index_attempt_id=1, db_session=mock_db_session, freq=5
        )

        # Test that heartbeat doesn't update before freq is reached
        for _ in range(4):
            heartbeat.heartbeat()

        mock_db_session.commit.assert_not_called()

        # Test that heartbeat updates when freq is reached
        heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once_with(
            db_session=mock_db_session, index_attempt_id=1
        )
        assert mock_index_attempt.time_updated is not None
        mock_db_session.commit.assert_called_once()

        # Reset mock calls
        mock_db_session.reset_mock()
        mock_get_index_attempt.reset_mock()

        # Test that heartbeat updates again after freq more calls
        for _ in range(5):
            heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once()
        mock_db_session.commit.assert_called_once()


def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
    with patch(
        "danswer.indexing.indexing_heartbeat.get_index_attempt"
    ) as mock_get_index_attempt, patch(
        "danswer.indexing.indexing_heartbeat.logger"
    ) as mock_logger:
        mock_get_index_attempt.return_value = None

        heartbeat = IndexingHeartbeat(
            index_attempt_id=1, db_session=mock_db_session, freq=1
        )

        heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once_with(
            db_session=mock_db_session, index_attempt_id=1
        )
        mock_logger.error.assert_called_once_with(
            "Index attempt not found, this should not happen!"
        )
        mock_db_session.commit.assert_not_called()
```
```diff
@@ -202,7 +202,7 @@ export const PromptLibraryTable = ({
         ))}
       </div>
     </div>
-    <div className="mx-auto">
+    <div className="mx-auto overflow-x-auto">
       <Table>
         <TableHead>
           <TableRow>

@@ -220,7 +220,16 @@ export const PromptLibraryTable = ({
             .map((item) => (
               <TableRow key={item.id}>
                 <TableCell>{item.prompt}</TableCell>
-                <TableCell>{item.content}</TableCell>
+                <TableCell
+                  className="
+                  max-w-xs
+                  overflow-hidden
+                  text-ellipsis
+                  break-words
+                  "
+                >
+                  {item.content}
+                </TableCell>
                 <TableCell>{item.active ? "Active" : "Inactive"}</TableCell>
                 <TableCell>
                   <button
```
```diff
@@ -162,6 +162,9 @@ export function ChatPage({
     user,
     availableAssistants
   );
+  const finalAssistants = user
+    ? orderAssistantsForUser(visibleAssistants, user)
+    : visibleAssistants;

   const existingChatSessionAssistantId = selectedChatSession?.persona_id;
   const [selectedAssistant, setSelectedAssistant] = useState<

@@ -216,7 +219,7 @@ export function ChatPage({
   const liveAssistant =
     alternativeAssistant ||
     selectedAssistant ||
-    visibleAssistants[0] ||
+    finalAssistants[0] ||
     availableAssistants[0];

   useEffect(() => {

@@ -686,7 +689,7 @@ export function ChatPage({
   useEffect(() => {
     if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
       setSelectedAssistant(
-        visibleAssistants.find((persona) => persona.id === defaultAssistantId)
+        finalAssistants.find((persona) => persona.id === defaultAssistantId)
       );
     }
   }, [defaultAssistantId]);

@@ -2390,10 +2393,7 @@ export function ChatPage({
               showDocs={() => setDocumentSelection(true)}
               selectedDocuments={selectedDocuments}
               // assistant stuff
-              assistantOptions={orderAssistantsForUser(
-                visibleAssistants,
-                user
-              )}
+              assistantOptions={finalAssistants}
               selectedAssistant={liveAssistant}
               setSelectedAssistant={onAssistantChange}
               setAlternativeAssistant={setAlternativeAssistant}
```
```diff
@@ -188,14 +188,6 @@ export async function fetchChatData(searchParams: {
     !hasAnyConnectors &&
     (!user || user.role === "admin");

-  const shouldDisplaySourcesIncompleteModal =
-    hasAnyConnectors &&
-    !shouldShowWelcomeModal &&
-    !ccPairs.some(
-      (ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
-    ) &&
-    (!user || user.role == "admin");
-
   // if no connectors are setup, only show personas that are pure
   // passthrough and don't do any retrieval
   if (!hasAnyConnectors) {
```