Compare commits

...

15 Commits

Author | SHA1 | Message | Date
rkuo-danswer | be4aef8696 | Merge pull request #2629 from danswer-ai/hotfix/v0.6-prompt-overflow (Hotfix/v0.6 prompt overflow) | 2024-09-30 12:26:15 -07:00
rkuo-danswer | c28a8d831b | Merge pull request #2630 from danswer-ai/hotfix/v0.6-jira-limit-size (Hotfix/v0.6 jira limit size) | 2024-09-30 12:25:59 -07:00
rkuo-danswer | 3cafedcf22 | Merge pull request #2631 from danswer-ai/hotfix/v0.6-heartbeat (Hotfix/v0.6 heartbeat) | 2024-09-30 12:25:48 -07:00
rkuo-danswer | f548164464 | Merge pull request #2632 from danswer-ai/hotfix/v0.6-default-assistant (Hotfix/v0.6 default assistant) | 2024-09-30 12:25:37 -07:00
Richard Kuo (Danswer) | 2fd557c7ea | Merge branch 'release/v0.6' of github.com:danswer-ai/danswer into hotfix/v0.6-prompt-overflow | 2024-09-30 11:29:12 -07:00
Richard Kuo (Danswer) | 3a8f06c765 | Merge branch 'release/v0.6' of github.com:danswer-ai/danswer into hotfix/v0.6-jira-limit-size | 2024-09-30 11:28:52 -07:00
Richard Kuo (Danswer) | 7f583420e2 | Merge branch 'release/v0.6' of github.com:danswer-ai/danswer into hotfix/v0.6-heartbeat | 2024-09-30 11:28:25 -07:00
Richard Kuo (Danswer) | c3f96c6be6 | Merge branch 'release/v0.6' of github.com:danswer-ai/danswer into hotfix/v0.6-default-assistant | 2024-09-30 11:27:56 -07:00
Richard Kuo (Danswer) | 8a9e02c25a | sync up branch protection rules | 2024-09-30 11:13:53 -07:00
Richard Kuo (Danswer) | fdd9ae347f | dummy commit to rerun github checks | 2024-09-30 10:32:03 -07:00
Richard Kuo (Danswer) | 85c1efcb25 | add indexing heartbeat | 2024-09-30 10:27:22 -07:00
Richard Kuo (Danswer) | f4d7e34fa3 | add size limit to jira tickets | 2024-09-30 09:58:35 -07:00
Richard Kuo (Danswer) | 0817f91e6a | fix default assistant | 2024-09-30 09:56:10 -07:00
Richard Kuo (Danswer) | af20c13d8b | Fix overflow of prompt library table | 2024-09-30 09:50:30 -07:00
Richard Kuo (Danswer) | fc45354aea | ssl_ca_certs should default to None, not "". (#2560) | 2024-09-25 13:16:08 -07:00
21 changed files with 506 additions and 78 deletions

View File

@@ -3,7 +3,9 @@ name: Python Checks
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
jobs:
mypy-check:

View File

@@ -3,7 +3,9 @@ name: Python Unit Tests
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
jobs:
backend-check:

View File

@@ -6,7 +6,9 @@ concurrency:
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -28,4 +28,6 @@ MacOS will likely require you to remove some quarantine attributes on some of th
After installing pre-commit, run the following command:
```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```
```
doc version 0.1

View File

@@ -29,6 +29,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
@@ -103,15 +104,24 @@ def _run_indexing(
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE),
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
)

View File

@@ -167,7 +167,7 @@ REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # br
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
@@ -247,6 +247,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
# Maximum size for Jira tickets in bytes (default: 100KB)
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
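
Note on the two settings changed above: with a default of "", an unset REDIS_SSL_CA_CERTS comes back as an empty string, which downstream code can mistake for a real CA-bundle path, whereas None cleanly means "not configured". The new Jira limit is parsed the same way. A minimal sketch of the parsing behaviour (values shown are just the defaults from this diff):

```python
import os

# Unset REDIS_SSL_CA_CERTS: a "" default yields an empty string that looks like
# a (bad) file path; a None default lets the SSL layer fall back to its defaults.
ssl_ca_certs = os.getenv("REDIS_SSL_CA_CERTS", None)

# The Jira limit defaults to 100 KiB and can be overridden per deployment.
max_ticket_size = int(os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024))

print(ssl_ca_certs, max_ticket_size)  # None 102400 when neither variable is exported
```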

View File

@@ -9,6 +9,7 @@ from jira.resources import Issue
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{description}\n" + "\n".join(
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
people = set()
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=semantic_rep)],
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
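
The connector change above renames semantic_rep to ticket_content and skips any issue whose combined description and comments exceed JIRA_CONNECTOR_MAX_TICKET_SIZE. A standalone sketch of that size check follows; build_ticket_content is an illustrative helper, not a function from the connector, and the limit is hard-coded to the default for brevity:

```python
MAX_TICKET_SIZE = 100 * 1024  # default of JIRA_CONNECTOR_MAX_TICKET_SIZE, in bytes


def build_ticket_content(description: str, comments: list[str]) -> str | None:
    """Return the text to index, or None if the ticket should be skipped."""
    content = f"{description}\n" + "\n".join(
        f"Comment: {comment}" for comment in comments if comment
    )
    # The limit is measured in UTF-8 bytes, not characters, as in the diff above.
    if len(content.encode("utf-8")) > MAX_TICKET_SIZE:
        return None
    return content


print(build_ticket_content("Small description", ["looks good"]) is not None)  # True
print(build_ticket_content("a" * 200_000, []))  # None: oversized ticket is dropped
```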

View File

@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -131,6 +133,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.heartbeat = heartbeat
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ class Chunker:
# If the chunk does not have any useable content, it will not be indexed
return chunks
def chunk(self, document: Document) -> list[DocAwareChunk]:
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ class Chunker:
normal_chunks.extend(large_chunks)
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
final_chunks.extend(self._handle_single_document(document))
if self.heartbeat:
self.heartbeat.heartbeat()
return final_chunks
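
Note the interface change: Chunker.chunk now takes a list of documents and pulses the optional heartbeat once per document, which is what lets the indexing pipeline drop its own per-document loop later in this diff. A distilled sketch of that control flow, detached from the real Chunker (chunk_one stands in for the per-document chunking logic):

```python
from typing import Callable

from danswer.indexing.indexing_heartbeat import Heartbeat


def chunk_all(
    documents: list[str],
    chunk_one: Callable[[str], list[str]],
    heartbeat: Heartbeat | None = None,
) -> list[str]:
    # Same shape as the new Chunker.chunk: accumulate per-document chunks and
    # report progress after every document so very long batches stay "alive".
    chunks: list[str] = []
    for document in documents:
        chunks.extend(chunk_one(document))
        if heartbeat:
            heartbeat.heartbeat()
    return chunks


print(chunk_all(["a b", "c"], chunk_one=str.split))  # ['a', 'b', 'c']
```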

View File

@@ -1,12 +1,8 @@
from abc import ABC
from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@ logger = setup_logger()
class IndexingEmbedder(ABC):
"""Converts chunks into chunks with embeddings. Note that one chunk may have
multiple embeddings associated with it."""
def __init__(
self,
model_name: str,
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
heartbeat=heartbeat,
)
@abstractmethod
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
heartbeat,
)
@log_function_time()
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls, search_settings: SearchSettings
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
heartbeat=heartbeat,
)
def get_embedding_model_from_search_settings(
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
) -> IndexingEmbedder:
search_settings: SearchSettings | None
if index_model_status == IndexModelStatus.PRESENT:
search_settings = get_current_search_settings(db_session)
elif index_model_status == IndexModelStatus.FUTURE:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
raise RuntimeError("No secondary index configured")
else:
raise RuntimeError("Not supporting embedding model rollbacks")
return DefaultIndexingEmbedder(
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)
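
Two things happen in the embedder hunk: chunk.dict() becomes chunk.model_dump() (the Pydantic v2 spelling), and an optional Heartbeat is threaded from from_db_search_settings through the constructor into the underlying EmbeddingModel, while the unused get_embedding_model_from_search_settings helper is removed. A hedged sketch of the new call shape, mirroring the _run_indexing hunk near the top of this diff; the search settings, attempt id, and session are assumed to come from the surrounding indexing job:

```python
from sqlalchemy.orm import Session

from danswer.db.models import SearchSettings
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat


def build_embedder(
    search_settings: SearchSettings,
    index_attempt_id: int,
    db_session: Session,
) -> DefaultIndexingEmbedder:
    # heartbeat is optional; omitting it keeps the pre-change behaviour, while
    # passing one makes long embedding jobs periodically touch the index attempt
    # so its time_updated column keeps moving.
    return DefaultIndexingEmbedder.from_db_search_settings(
        search_settings=search_settings,
        heartbeat=IndexingHeartbeat(
            index_attempt_id=index_attempt_id,
            db_session=db_session,
            freq=10,  # commit a progress update every 10th heartbeat
        ),
    )
```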

View File

@@ -0,0 +1,41 @@
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")

View File

@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -283,18 +284,10 @@ def index_doc_batch(
return 0, 0
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = []
for document in ctx.updatable_docs:
chunks.extend(chunker.chunk(document=document))
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings = (
embedder.embed_chunks(
chunks=chunks,
)
if chunks
else []
)
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
updatable_ids = [doc.id for doc in ctx.updatable_docs]
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)
return partial(

View File

@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -107,6 +109,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.heartbeat = heartbeat
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@ class EmbeddingModel:
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings
def encode(

View File

@@ -42,7 +42,7 @@ class RedisPool:
db: int = REDIS_DB_NUMBER,
password: str = REDIS_PASSWORD,
max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
ssl: bool = False,
) -> redis.ConnectionPool:
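
The RedisPool signature now accepts str | None for ssl_ca_certs, matching the config default change: redis-py treats a non-None value as a path to a CA bundle, so an empty string can make it try to load "" as a file, while None simply means no explicit CA file. A hedged sketch of building an SSL pool with these settings using standard redis-py arguments (host/port/db values are placeholders, and the real RedisPool wrapper may wire things differently):

```python
import redis

# ssl_ca_certs=None lets the TLS layer use its defaults instead of trying to
# open "" as a certificate file; ssl_cert_reqs="none" mirrors REDIS_SSL_CERT_REQS.
pool = redis.ConnectionPool(
    connection_class=redis.SSLConnection,
    host="localhost",  # placeholder
    port=6379,         # placeholder
    db=0,
    max_connections=10,
    ssl_ca_certs=None,
    ssl_cert_reqs="none",
)
client = redis.Redis(connection_pool=pool)  # no connection is made until a command runs
```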

View File

@@ -0,0 +1,136 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from jira.resources import Issue
from pytest_mock import MockFixture
from danswer.connectors.danswer_jira.connector import fetch_jira_issues_batch
@pytest.fixture
def mock_jira_client() -> MagicMock:
return MagicMock()
@pytest.fixture
def mock_issue_small() -> MagicMock:
issue = MagicMock()
issue.key = "SMALL-1"
issue.fields.description = "Small description"
issue.fields.comment.comments = [
MagicMock(body="Small comment 1"),
MagicMock(body="Small comment 2"),
]
issue.fields.creator.displayName = "John Doe"
issue.fields.creator.emailAddress = "john@example.com"
issue.fields.summary = "Small Issue"
issue.fields.updated = "2023-01-01T00:00:00+0000"
issue.fields.labels = []
return issue
@pytest.fixture
def mock_issue_large() -> MagicMock:
# This will be larger than 100KB
issue = MagicMock()
issue.key = "LARGE-1"
issue.fields.description = "a" * 99_000
issue.fields.comment.comments = [
MagicMock(body="Large comment " * 1000),
MagicMock(body="Another large comment " * 1000),
]
issue.fields.creator.displayName = "Jane Doe"
issue.fields.creator.emailAddress = "jane@example.com"
issue.fields.summary = "Large Issue"
issue.fields.updated = "2023-01-02T00:00:00+0000"
issue.fields.labels = []
return issue
@pytest.fixture
def patched_type() -> Callable[[Any], type]:
def _patched_type(obj: Any) -> type:
if isinstance(obj, MagicMock):
return Issue
return type(obj)
return _patched_type
@pytest.fixture
def mock_jira_api_version() -> Generator[Any, Any, Any]:
with patch("danswer.connectors.danswer_jira.connector.JIRA_API_VERSION", "2"):
yield
@pytest.fixture
def patched_environment(
patched_type: type,
mock_jira_api_version: MockFixture,
) -> Generator[Any, Any, Any]:
with patch("danswer.connectors.danswer_jira.connector.type", patched_type):
yield
def test_fetch_jira_issues_batch_small_ticket(
mock_jira_client: MagicMock,
mock_issue_small: MagicMock,
patched_environment: MockFixture,
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 1
assert len(docs) == 1
assert docs[0].id.endswith("/SMALL-1")
assert "Small description" in docs[0].sections[0].text
assert "Small comment 1" in docs[0].sections[0].text
assert "Small comment 2" in docs[0].sections[0].text
def test_fetch_jira_issues_batch_large_ticket(
mock_jira_client: MagicMock,
mock_issue_large: MagicMock,
patched_environment: MockFixture,
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 1
assert len(docs) == 0 # The large ticket should be skipped
def test_fetch_jira_issues_batch_mixed_tickets(
mock_jira_client: MagicMock,
mock_issue_small: MagicMock,
mock_issue_large: MagicMock,
patched_environment: MockFixture,
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 2
assert len(docs) == 1 # Only the small ticket should be included
assert docs[0].id.endswith("/SMALL-1")
@patch("danswer.connectors.danswer_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50)
def test_fetch_jira_issues_batch_custom_size_limit(
mock_jira_client: MagicMock,
mock_issue_small: MagicMock,
mock_issue_large: MagicMock,
patched_environment: MockFixture,
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 2
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit

View File

@@ -0,0 +1,18 @@
from typing import Any
import pytest
from danswer.indexing.indexing_heartbeat import Heartbeat
class MockHeartbeat(Heartbeat):
def __init__(self) -> None:
self.call_count = 0
def heartbeat(self, metadata: Any = None) -> None:
self.call_count += 1
@pytest.fixture
def mock_heartbeat() -> MockHeartbeat:
return MockHeartbeat()

View File

@@ -1,11 +1,24 @@
import pytest
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import DefaultIndexingEmbedder
from tests.unit.danswer.indexing.conftest import MockHeartbeat
def test_chunk_document() -> None:
@pytest.fixture
def embedder() -> DefaultIndexingEmbedder:
return DefaultIndexingEmbedder(
model_name="intfloat/e5-base-v2",
normalize=True,
query_prefix=None,
passage_prefix=None,
)
def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
short_section_1 = "This is a short section."
long_section = (
"This is a long section that should be split into multiple chunks. " * 100
@@ -30,18 +43,11 @@ def test_chunk_document() -> None:
],
)
embedder = DefaultIndexingEmbedder(
model_name="intfloat/e5-base-v2",
normalize=True,
query_prefix=None,
passage_prefix=None,
)
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
)
chunks = chunker.chunk(document)
chunks = chunker.chunk([document])
assert len(chunks) == 5
assert short_section_1 in chunks[0].content
@@ -49,3 +55,29 @@ def test_chunk_document() -> None:
assert short_section_4 in chunks[-1].content
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic
def test_chunker_heartbeat(
embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
) -> None:
document = Document(
id="test_doc",
source=DocumentSource.WEB,
semantic_identifier="Test Document",
metadata={"tags": ["tag1", "tag2"]},
doc_updated_at=None,
sections=[
Section(text="This is a short section.", link="link1"),
],
)
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
heartbeat=mock_heartbeat,
)
chunks = chunker.chunk([document])
assert mock_heartbeat.call_count == 1
assert len(chunks) > 0

View File

@@ -0,0 +1,90 @@
from collections.abc import Generator
from unittest.mock import Mock
from unittest.mock import patch
import pytest
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
@pytest.fixture
def mock_embedding_model() -> Generator[Mock, None, None]:
with patch("danswer.indexing.embedder.EmbeddingModel") as mock:
yield mock
def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
# Setup
embedder = DefaultIndexingEmbedder(
model_name="test-model",
normalize=True,
query_prefix=None,
passage_prefix=None,
provider_type=EmbeddingProvider.OPENAI,
)
# Mock the encode method of the embedding model
mock_embedding_model.return_value.encode.side_effect = [
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], # Main chunk embeddings
[[7.0, 8.0, 9.0]], # Title embedding
]
# Create test input
source_doc = Document(
id="test_doc",
source=DocumentSource.WEB,
semantic_identifier="Test Document",
metadata={"tags": ["tag1", "tag2"]},
doc_updated_at=None,
sections=[
Section(text="This is a short section.", link="link1"),
],
)
chunks: list[DocAwareChunk] = [
DocAwareChunk(
chunk_id=1,
blurb="This is a short section.",
content="Test chunk",
source_links={0: "link1"},
section_continuation=False,
source_document=source_doc,
title_prefix="Title: ",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_reference_ids=[],
)
]
# Execute
result: list[IndexChunk] = embedder.embed_chunks(chunks)
# Assert
assert len(result) == 1
assert isinstance(result[0], IndexChunk)
assert result[0].content == "Test chunk"
assert result[0].embeddings == ChunkEmbedding(
full_embedding=[1.0, 2.0, 3.0],
mini_chunk_embeddings=[],
)
assert result[0].title_embedding == [7.0, 8.0, 9.0]
# Verify the embedding model was called correctly
mock_embedding_model.return_value.encode.assert_any_call(
texts=["Title: Test chunk"],
text_type=EmbedTextType.PASSAGE,
large_chunks_present=False,
)
# title only embedding call
mock_embedding_model.return_value.encode.assert_any_call(
["Test Document"],
text_type=EmbedTextType.PASSAGE,
)

View File

@@ -0,0 +1,80 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from danswer.db.index_attempt import IndexAttempt
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
@pytest.fixture
def mock_db_session() -> MagicMock:
return MagicMock(spec=Session)
@pytest.fixture
def mock_index_attempt() -> MagicMock:
return MagicMock(spec=IndexAttempt)
def test_indexing_heartbeat(
mock_db_session: MagicMock, mock_index_attempt: MagicMock
) -> None:
with patch(
"danswer.indexing.indexing_heartbeat.get_index_attempt"
) as mock_get_index_attempt:
mock_get_index_attempt.return_value = mock_index_attempt
heartbeat = IndexingHeartbeat(
index_attempt_id=1, db_session=mock_db_session, freq=5
)
# Test that heartbeat doesn't update before freq is reached
for _ in range(4):
heartbeat.heartbeat()
mock_db_session.commit.assert_not_called()
# Test that heartbeat updates when freq is reached
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once_with(
db_session=mock_db_session, index_attempt_id=1
)
assert mock_index_attempt.time_updated is not None
mock_db_session.commit.assert_called_once()
# Reset mock calls
mock_db_session.reset_mock()
mock_get_index_attempt.reset_mock()
# Test that heartbeat updates again after freq more calls
for _ in range(5):
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
with patch(
"danswer.indexing.indexing_heartbeat.get_index_attempt"
) as mock_get_index_attempt, patch(
"danswer.indexing.indexing_heartbeat.logger"
) as mock_logger:
mock_get_index_attempt.return_value = None
heartbeat = IndexingHeartbeat(
index_attempt_id=1, db_session=mock_db_session, freq=1
)
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once_with(
db_session=mock_db_session, index_attempt_id=1
)
mock_logger.error.assert_called_once_with(
"Index attempt not found, this should not happen!"
)
mock_db_session.commit.assert_not_called()

View File

@@ -202,7 +202,7 @@ export const PromptLibraryTable = ({
))}
</div>
</div>
<div className="mx-auto">
<div className="mx-auto overflow-x-auto">
<Table>
<TableHead>
<TableRow>
@@ -220,7 +220,16 @@ export const PromptLibraryTable = ({
.map((item) => (
<TableRow key={item.id}>
<TableCell>{item.prompt}</TableCell>
<TableCell>{item.content}</TableCell>
<TableCell
className="
max-w-xs
overflow-hidden
text-ellipsis
break-words
"
>
{item.content}
</TableCell>
<TableCell>{item.active ? "Active" : "Inactive"}</TableCell>
<TableCell>
<button

View File

@@ -162,6 +162,9 @@ export function ChatPage({
user,
availableAssistants
);
const finalAssistants = user
? orderAssistantsForUser(visibleAssistants, user)
: visibleAssistants;
const existingChatSessionAssistantId = selectedChatSession?.persona_id;
const [selectedAssistant, setSelectedAssistant] = useState<
@@ -216,7 +219,7 @@ export function ChatPage({
const liveAssistant =
alternativeAssistant ||
selectedAssistant ||
visibleAssistants[0] ||
finalAssistants[0] ||
availableAssistants[0];
useEffect(() => {
@@ -686,7 +689,7 @@ export function ChatPage({
useEffect(() => {
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
setSelectedAssistant(
visibleAssistants.find((persona) => persona.id === defaultAssistantId)
finalAssistants.find((persona) => persona.id === defaultAssistantId)
);
}
}, [defaultAssistantId]);
@@ -2390,10 +2393,7 @@ export function ChatPage({
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
// assistant stuff
assistantOptions={orderAssistantsForUser(
visibleAssistants,
user
)}
assistantOptions={finalAssistants}
selectedAssistant={liveAssistant}
setSelectedAssistant={onAssistantChange}
setAlternativeAssistant={setAlternativeAssistant}

View File

@@ -188,14 +188,6 @@ export async function fetchChatData(searchParams: {
!hasAnyConnectors &&
(!user || user.role === "admin");
const shouldDisplaySourcesIncompleteModal =
hasAnyConnectors &&
!shouldShowWelcomeModal &&
!ccPairs.some(
(ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
) &&
(!user || user.role == "admin");
// if no connectors are setup, only show personas that are pure
// passthrough and don't do any retrieval
if (!hasAnyConnectors) {