Compare commits


28 Commits

Author SHA1 Message Date
rkuo-danswer
9d4ecf817f Merge pull request #2755 from danswer-ai/hotfix/v0.6-supervisord-logs
backport update indexing and slack bot to use stdout options (#2752)
2024-10-09 18:26:19 -07:00
Richard Kuo (Danswer)
881f814030 backport update indexing and slack bot to use stdout options (#2752) 2024-10-09 18:05:52 -07:00
rkuo-danswer
40013454f8 Merge pull request #2745 from danswer-ai/hotfix/v0.6-docker-kv-deprecation
Merge hotfix/v0.6-docker-kv-deprecation into release/v0.6
2024-10-09 12:54:51 -07:00
rkuo-danswer
11080d3d69 Merge pull request #2717 from danswer-ai/bugfix/docker-legacy-key-value-format
Fix all LegacyKeyValueFormat docker warnings
2024-10-09 19:41:36 +00:00
rkuo-danswer
acaff41457 Merge pull request #2646 from danswer-ai/slack-fix
Added quotes to project name to handle reserved words (#2639)
2024-10-01 12:14:36 -07:00
hagen-danswer
dc35e9f6da Added quotes to project name to handle reserved words (#2639) 2024-10-01 11:29:01 -07:00
rkuo-danswer
cdace77209 Merge pull request #2645 from danswer-ai/hotfix/v0.6.1-celery
bump celery for potential fix to redis disconnection behavior
2024-10-01 11:05:05 -07:00
Richard Kuo (Danswer)
7e75e7c49e bump celery for potential fix to redis disconnection behavior 2024-10-01 10:43:09 -07:00
Richard Kuo (Danswer)
fc18fd5b19 Revert "update celery"
This reverts commit 6dc9649e30.
2024-10-01 10:41:37 -07:00
Richard Kuo (Danswer)
6dc9649e30 update celery 2024-10-01 10:39:33 -07:00
rkuo-danswer
65c1918159 Merge pull request #2611 from danswer-ai/fixes/user-group-fk
Group fk fix
2024-10-01 09:49:50 -07:00
Richard Kuo (Danswer)
7678a356f5 Merge branch 'release/v0.6' of https://github.com/danswer-ai/danswer into fixes/user-group-fk 2024-10-01 09:17:09 -07:00
rkuo-danswer
be4aef8696 Merge pull request #2629 from danswer-ai/hotfix/v0.6-prompt-overflow
Hotfix/v0.6 prompt overflow
2024-09-30 12:26:15 -07:00
rkuo-danswer
c28a8d831b Merge pull request #2630 from danswer-ai/hotfix/v0.6-jira-limit-size
Hotfix/v0.6 jira limit size
2024-09-30 12:25:59 -07:00
rkuo-danswer
3cafedcf22 Merge pull request #2631 from danswer-ai/hotfix/v0.6-heartbeat
Hotfix/v0.6 heartbeat
2024-09-30 12:25:48 -07:00
rkuo-danswer
f548164464 Merge pull request #2632 from danswer-ai/hotfix/v0.6-default-assistant
Hotfix/v0.6 default assistant
2024-09-30 12:25:37 -07:00
Richard Kuo (Danswer)
2fd557c7ea Merge branch 'release/v0.6' of github.com:danswer-ai/danswer into hotfix/v0.6-prompt-overflow 2024-09-30 11:29:12 -07:00
Richard Kuo (Danswer)
3a8f06c765 Merge branch 'release/v0.6' of github.com:danswer-ai/danswer into hotfix/v0.6-jira-limit-size 2024-09-30 11:28:52 -07:00
Richard Kuo (Danswer)
7f583420e2 Merge branch 'release/v0.6' of github.com:danswer-ai/danswer into hotfix/v0.6-heartbeat 2024-09-30 11:28:25 -07:00
Richard Kuo (Danswer)
c3f96c6be6 Merge branch 'release/v0.6' of github.com:danswer-ai/danswer into hotfix/v0.6-default-assistant 2024-09-30 11:27:56 -07:00
Richard Kuo (Danswer)
8a9e02c25a sync up branch protection rules 2024-09-30 11:13:53 -07:00
Richard Kuo (Danswer)
fdd9ae347f dummy commit to rerun github checks 2024-09-30 10:32:03 -07:00
Richard Kuo (Danswer)
85c1efcb25 add indexing heartbeat 2024-09-30 10:27:22 -07:00
Richard Kuo (Danswer)
f4d7e34fa3 add size limit to jira tickets 2024-09-30 09:58:35 -07:00
Richard Kuo (Danswer)
0817f91e6a fix default assistant 2024-09-30 09:56:10 -07:00
Richard Kuo (Danswer)
af20c13d8b Fix overflow of prompt library table 2024-09-30 09:50:30 -07:00
hagen-danswer
d724da9474 Group fk fix 2024-09-30 09:37:18 -07:00
Richard Kuo (Danswer)
fc45354aea ssl_ca_certs should default to None, not "". (#2560) 2024-09-25 13:16:08 -07:00
30 changed files with 694 additions and 199 deletions

View File

@@ -3,7 +3,9 @@ name: Python Checks
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
jobs:
mypy-check:

View File

@@ -15,6 +15,9 @@ env:
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
jobs:
connectors-check:

View File

@@ -3,7 +3,9 @@ name: Python Unit Tests
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
jobs:
backend-check:

View File

@@ -6,7 +6,9 @@ concurrency:
on:
merge_group:
pull_request:
branches: [ main ]
branches:
- main
- 'release/**'
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -28,4 +28,6 @@ MacOS will likely require you to remove some quarantine attributes on some of th
After installing pre-commit, run the following command:
```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```
```
doc version 0.1

View File

@@ -101,7 +101,7 @@ COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connect
# Put logo in assets
COPY ./assets /app/assets
ENV PYTHONPATH /app
ENV PYTHONPATH=/app
# Default command which does nothing
# This container is used by api server and background which specify their own CMD

View File

@@ -55,6 +55,6 @@ COPY ./shared_configs /app/shared_configs
# Model Server main code
COPY ./model_server /app/model_server
ENV PYTHONPATH /app
ENV PYTHONPATH=/app
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]

View File

@@ -29,6 +29,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
@@ -103,15 +104,24 @@ def _run_indexing(
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE),
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
),
db_session=db_session,
)

View File

@@ -167,7 +167,7 @@ REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # br
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
@@ -247,6 +247,10 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
if ignored_tag
]
# Maximum size for Jira tickets in bytes (default: 100KB)
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
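As a side note on the configuration change above, this is the standard integer-environment-override pattern: the limit can be raised or lowered per deployment, and a non-numeric override fails fast with a ValueError. A small standalone sketch (the setdefault call is only there to simulate an override):

```python
import os

# Simulate a deployment that raises the limit to 200 KB.
os.environ.setdefault("JIRA_CONNECTOR_MAX_TICKET_SIZE", str(200 * 1024))

# Same pattern as the config line above: integer env override with a 100 KB default.
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
    os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)

print(JIRA_CONNECTOR_MAX_TICKET_SIZE)  # 204800 here; 102400 when the variable is unset
```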

View File

@@ -9,6 +9,7 @@ from jira.resources import Issue
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -134,10 +135,18 @@ def fetch_jira_issues_batch(
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{description}\n" + "\n".join(
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
people = set()
@@ -180,7 +189,7 @@ def fetch_jira_issues_batch(
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=semantic_rep)],
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
@@ -236,10 +245,12 @@ class JiraConnector(LoadConnector, PollConnector):
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {self.jira_project}",
jql=f"project = {quoted_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
@@ -267,8 +278,10 @@ class JiraConnector(LoadConnector, PollConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {self.jira_project} AND "
f"project = {quoted_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
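Taken together, the connector changes do two things: drop tickets whose rendered content exceeds the byte limit (measured on the UTF-8 encoding, not the character count) and quote the project name in JQL so reserved words do not break the query. A condensed, standalone sketch of both behaviors; the helper names are illustrative rather than the connector's actual API:

```python
MAX_TICKET_SIZE = 100 * 1024  # bytes, matching the new default above


def build_ticket_content(description: str, comments: list[str]) -> str | None:
    ticket_content = f"{description}\n" + "\n".join(
        f"Comment: {comment}" for comment in comments if comment
    )
    # Skip oversized tickets instead of indexing them; size is UTF-8 bytes.
    if len(ticket_content.encode("utf-8")) > MAX_TICKET_SIZE:
        return None
    return ticket_content


def build_project_jql(project: str, start: str | None = None, end: str | None = None) -> str:
    # Quoting the project handles names that collide with JQL reserved words.
    jql = f'project = "{project}"'
    if start and end:
        jql += f" AND updated >= '{start}' AND updated <= '{end}'"
    return jql


print(build_project_jql("AS", "2024-09-01 00:00", "2024-09-30 23:59"))
```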

View File

@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -131,6 +133,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.heartbeat = heartbeat
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ class Chunker:
# If the chunk does not have any useable content, it will not be indexed
return chunks
def chunk(self, document: Document) -> list[DocAwareChunk]:
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ class Chunker:
normal_chunks.extend(large_chunks)
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
final_chunks.extend(self._handle_single_document(document))
if self.heartbeat:
self.heartbeat.heartbeat()
return final_chunks
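The chunker's public method now takes the whole document batch and emits one heartbeat per document, so even a run dominated by a few very long documents keeps reporting progress. A stripped-down sketch of that loop, with a Protocol standing in for the Heartbeat abstraction added in this changeset and trivial word splitting in place of real chunking:

```python
from typing import Any, Protocol


class Heartbeat(Protocol):
    def heartbeat(self, metadata: Any = None) -> None: ...


class BatchChunker:
    def __init__(self, heartbeat: Heartbeat | None = None) -> None:
        self.heartbeat = heartbeat

    def _handle_single_document(self, document: str) -> list[str]:
        # Real chunking is token and sentence aware; word splitting keeps the sketch short.
        return document.split()

    def chunk(self, documents: list[str]) -> list[str]:
        final_chunks: list[str] = []
        for document in documents:
            final_chunks.extend(self._handle_single_document(document))
            if self.heartbeat:
                # One heartbeat per document, so long documents still report progress.
                self.heartbeat.heartbeat()
        return final_chunks
```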

View File

@@ -1,12 +1,8 @@
from abc import ABC
from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@ logger = setup_logger()
class IndexingEmbedder(ABC):
"""Converts chunks into chunks with embeddings. Note that one chunk may have
multiple embeddings associated with it."""
def __init__(
self,
model_name: str,
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
heartbeat=heartbeat,
)
@abstractmethod
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
heartbeat,
)
@log_function_time()
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls, search_settings: SearchSettings
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
heartbeat=heartbeat,
)
def get_embedding_model_from_search_settings(
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
) -> IndexingEmbedder:
search_settings: SearchSettings | None
if index_model_status == IndexModelStatus.PRESENT:
search_settings = get_current_search_settings(db_session)
elif index_model_status == IndexModelStatus.FUTURE:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
raise RuntimeError("No secondary index configured")
else:
raise RuntimeError("Not supporting embedding model rollbacks")
return DefaultIndexingEmbedder(
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)
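Besides threading the optional heartbeat from from_db_search_settings down to the embedding model, this hunk swaps chunk.dict() for chunk.model_dump(), which is the Pydantic v2 name for the same serialization call. A minimal illustration of that rename (the Chunk model here is a placeholder, not the project's):

```python
from pydantic import BaseModel


class Chunk(BaseModel):
    chunk_id: int
    content: str


chunk = Chunk(chunk_id=1, content="example")

# Pydantic v2: model_dump() replaces the deprecated v1 .dict() method.
data = chunk.model_dump()
new_chunk_fields = {**data, "content": data["content"].upper()}
print(new_chunk_fields)  # {'chunk_id': 1, 'content': 'EXAMPLE'}
```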

View File

@@ -0,0 +1,41 @@
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")
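Any long-running stage can supply its own Heartbeat implementation; the unit tests later in this changeset use a simple counting variant. A sketch along those lines that logs instead of updating the database (LoggingHeartbeat is illustrative, not part of the change):

```python
import abc
import logging
from typing import Any

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Heartbeat(abc.ABC):
    """Progress hook for long-running, item-by-item work (chunking, embedding, ...)."""

    @abc.abstractmethod
    def heartbeat(self, metadata: Any = None) -> None:
        raise NotImplementedError


class LoggingHeartbeat(Heartbeat):
    """Illustrative subclass: logs every `freq` calls instead of touching the database."""

    def __init__(self, freq: int = 10) -> None:
        self.cnt = 0
        self.freq = freq

    def heartbeat(self, metadata: Any = None) -> None:
        self.cnt += 1
        if self.cnt % self.freq == 0:
            logger.info("still alive after %d items", self.cnt)


hb = LoggingHeartbeat(freq=2)
for _ in range(4):
    hb.heartbeat()  # logs at call 2 and call 4
```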

View File

@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -283,18 +284,10 @@ def index_doc_batch(
return 0, 0
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = []
for document in ctx.updatable_docs:
chunks.extend(chunker.chunk(document=document))
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings = (
embedder.embed_chunks(
chunks=chunks,
)
if chunks
else []
)
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
updatable_ids = [doc.id for doc in ctx.updatable_docs]
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)
return partial(

View File

@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -107,6 +109,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.heartbeat = heartbeat
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@ class EmbeddingModel:
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings
def encode(
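At the embedding layer the same hook fires once per model-server batch rather than once per document, so a single large request list cannot stall the progress timestamp. Roughly, under the same Protocol assumption as above:

```python
from typing import Any, Callable, Protocol


class Heartbeat(Protocol):
    def heartbeat(self, metadata: Any = None) -> None: ...


def embed_in_batches(
    texts: list[str],
    batch_size: int,
    embed_batch: Callable[[list[str]], list[list[float]]],
    heartbeat: Heartbeat | None = None,
) -> list[list[float]]:
    embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        embeddings.extend(embed_batch(texts[start : start + batch_size]))
        if heartbeat:
            # Report progress after every model-server round trip.
            heartbeat.heartbeat()
    return embeddings
```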

View File

@@ -42,7 +42,7 @@ class RedisPool:
db: int = REDIS_DB_NUMBER,
password: str = REDIS_PASSWORD,
max_connections: int = REDIS_POOL_MAX_CONNECTIONS,
ssl_ca_certs: str = REDIS_SSL_CA_CERTS,
ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS,
ssl_cert_reqs: str = REDIS_SSL_CERT_REQS,
ssl: bool = False,
) -> redis.ConnectionPool:
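The default changed from "" to None because the value is ultimately handed to redis-py as a certificate-file path, and an empty string reads as a (nonexistent) path rather than "not configured". A small sketch of the safer lookup; the helper name is illustrative:

```python
import os


def load_redis_ssl_ca_certs() -> str | None:
    """Return a CA bundle path, or None when nothing is configured.

    Defaulting to "" used to hand the Redis client an empty (invalid) path;
    None lets it fall back to the platform's default CA bundle.
    """
    value = os.getenv("REDIS_SSL_CA_CERTS")
    return value or None  # treat "" the same as unset


print(load_redis_ssl_ca_certs())
```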

View File

@@ -18,6 +18,7 @@ from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import DocumentSet__UserGroup
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import Persona__UserGroup
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
@@ -33,6 +34,93 @@ from ee.danswer.server.user_group.models import UserGroupUpdate
logger = setup_logger()
def _cleanup_user__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)
user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)
def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_persona__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)
def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)
def _cleanup_document_set__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.execute(
delete(DocumentSet__UserGroup).where(
DocumentSet__UserGroup.user_group_id == user_group_id
)
)
def validate_user_creation_permissions(
db_session: Session,
user: User | None,
@@ -286,42 +374,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
return db_user_group
def _cleanup_user__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
user_ids: list[UUID] | None = None,
) -> None:
"""NOTE: does not commit the transaction."""
where_clause = User__UserGroup.user_group_id == user_group_id
if user_ids:
where_clause &= User__UserGroup.user_id.in_(user_ids)
user__user_group_relationships = db_session.scalars(
select(User__UserGroup).where(where_clause)
).all()
for user__user_group_relationship in user__user_group_relationships:
db_session.delete(user__user_group_relationship)
def _cleanup_credential__user_group_relationships__no_commit(
db_session: Session,
user_group_id: int,
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(Credential__UserGroup).filter(
Credential__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _cleanup_llm_provider__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.user_group_id == user_group_id
).delete(synchronize_session=False)
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session: Session, user_group_id: int
) -> None:
@@ -476,21 +528,6 @@ def update_user_group(
return db_user_group
def _cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
token_rate_limit__user_group_relationships = db_session.scalars(
select(TokenRateLimit__UserGroup).where(
TokenRateLimit__UserGroup.user_group_id == user_group_id
)
).all()
for (
token_rate_limit__user_group_relationship
) in token_rate_limit__user_group_relationships:
db_session.delete(token_rate_limit__user_group_relationship)
def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
db_user_group = db_session.scalar(stmt)
@@ -499,16 +536,31 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_credential__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_mark_user_group__cc_pair_relationships_outdated__no_commit(
_cleanup_token_rate_limit__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_token_rate_limit__user_group_relationships__no_commit(
_cleanup_document_set__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_persona__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group_id,
outdated_only=False,
)
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group_id
)
@@ -517,31 +569,12 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session.commit()
def _cleanup_user_group__cc_pair_relationships__no_commit(
db_session: Session, user_group_id: int, outdated_only: bool
) -> None:
"""NOTE: does not commit the transaction."""
stmt = select(UserGroup__ConnectorCredentialPair).where(
UserGroup__ConnectorCredentialPair.user_group_id == user_group_id
)
if outdated_only:
stmt = stmt.where(
UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712
)
user_group__cc_pair_relationships = db_session.scalars(stmt)
for user_group__cc_pair_relationship in user_group__cc_pair_relationships:
db_session.delete(user_group__cc_pair_relationship)
def _cleanup_document_set__user_group_relationships__no_commit(
db_session: Session, user_group_id: int
) -> None:
"""NOTE: does not commit the transaction."""
db_session.execute(
delete(DocumentSet__UserGroup).where(
DocumentSet__UserGroup.user_group_id == user_group_id
)
)
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
"""
This assumes that all the fk cleanup has already been done.
"""
db_session.delete(user_group)
db_session.commit()
def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None:
@@ -553,29 +586,6 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
db_session.commit()
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
_cleanup_llm_provider__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
_cleanup_user_group__cc_pair_relationships__no_commit(
db_session=db_session,
user_group_id=user_group.id,
outdated_only=False,
)
_cleanup_document_set__user_group_relationships__no_commit(
db_session=db_session, user_group_id=user_group.id
)
# need to flush so that we don't get a foreign key error when deleting the user group row
db_session.flush()
db_session.delete(user_group)
db_session.commit()
def delete_user_group_cc_pair_relationship__no_commit(
cc_pair_id: int, db_session: Session
) -> None:
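All of the reshuffled helpers follow the same ordering that the new prepare/delete flow relies on: remove every relationship row that references the group, flush so the foreign keys are gone, then delete the parent row. A schematic SQLAlchemy sketch of that pattern against an in-memory database; the UserGroup and GroupMembership models here are placeholders, not the project's:

```python
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class UserGroup(Base):
    __tablename__ = "user_group"
    id: Mapped[int] = mapped_column(primary_key=True)


class GroupMembership(Base):
    __tablename__ = "group_membership"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_group_id: Mapped[int] = mapped_column(ForeignKey("user_group.id"))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    group = UserGroup(id=1)
    session.add_all([group, GroupMembership(id=1, user_group_id=1)])
    session.commit()

    # 1) Remove every relationship row that references the group ...
    session.query(GroupMembership).filter(
        GroupMembership.user_group_id == group.id
    ).delete(synchronize_session=False)
    # 2) ... flush so the FK references are gone before the parent goes away ...
    session.flush()
    # 3) ... then delete the parent row and commit once.
    session.delete(group)
    session.commit()
```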

View File

@@ -4,7 +4,7 @@ asyncpg==0.27.0
atlassian-python-api==3.37.0
beautifulsoup4==4.12.2
boto3==1.34.84
celery==5.3.4
celery==5.5.0b4
chardet==5.2.0
dask==2023.8.1
ddtrace==2.6.5

View File

@@ -7,12 +7,13 @@ logfile=/var/log/supervisord.log
# Cannot place this in Celery for now because Celery must run as a single process (see note below)
# Indexing uses multi-processing to speed things up
[program:document_indexing]
environment=CURRENT_PROCESS_IS_AN_INDEXING_JOB=true,LOG_FILE_NAME=document_indexing
environment=CURRENT_PROCESS_IS_AN_INDEXING_JOB=true
command=python danswer/background/update.py
stdout_logfile=/var/log/document_indexing.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
# Background jobs that must be run async due to long time to completion
# NOTE: due to an issue with Celery + SQLAlchemy
# (https://github.com/celery/celery/issues/7007#issuecomment-1740139367)
@@ -31,7 +32,8 @@ command=celery -A danswer.background.celery.celery_run:celery_app worker
--loglevel=INFO
--logfile=/var/log/celery_worker_supervisor.log
-Q celery,vespa_metadata_sync,connector_deletion
environment=LOG_FILE_NAME=celery_worker
stdout_logfile=/var/log/celery_worker.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
@@ -39,7 +41,8 @@ autorestart=true
[program:celery_beat]
command=celery -A danswer.background.celery.celery_run:celery_app beat
--logfile=/var/log/celery_beat_supervisor.log
environment=LOG_FILE_NAME=celery_beat
stdout_logfile=/var/log/celery_beat.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
# Listens for Slack messages and responds with answers
@@ -48,7 +51,8 @@ redirect_stderr=true
# More details on setup here: https://docs.danswer.dev/slack_bot_setup
[program:slack_bot]
command=python danswer/danswerbot/slack/listener.py
environment=LOG_FILE_NAME=slack_bot
stdout_logfile=/var/log/slack_bot.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startretries=5
@@ -58,12 +62,10 @@ startsecs=60
# No log rotation here, since it's stdout it's handled by the Docker container logging
[program:log-redirect-handler]
command=tail -qF
/var/log/document_indexing_info.log
/var/log/celery_beat_supervisor.log
/var/log/celery_worker_supervisor.log
/var/log/celery_beat_debug.log
/var/log/celery_worker_debug.log
/var/log/slack_bot_debug.log
/var/log/document_indexing.log
/var/log/celery_beat.log
/var/log/celery_worker.log
/var/log/slack_bot.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true

View File

@@ -0,0 +1,48 @@
import os
import time
import pytest
from danswer.configs.constants import DocumentSource
from danswer.connectors.danswer_jira.connector import JiraConnector
@pytest.fixture
def jira_connector() -> JiraConnector:
connector = JiraConnector(
"https://danswerai.atlassian.net/jira/software/c/projects/AS/boards/6",
comment_email_blacklist=[],
)
connector.load_credentials(
{
"jira_user_email": os.environ["JIRA_USER_EMAIL"],
"jira_api_token": os.environ["JIRA_API_TOKEN"],
}
)
return connector
def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
doc_batch_generator = jira_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 1
doc = doc_batch[0]
assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
assert doc.semantic_identifier == "test123small"
assert doc.source == DocumentSource.JIRA
assert doc.metadata == {"priority": "Medium", "status": "Backlog"}
assert doc.secondary_owners is None
assert doc.title is None
assert doc.from_ingestion_api is False
assert doc.additional_info is None
assert len(doc.sections) == 1
section = doc.sections[0]
assert section.text == "example_text\n"
assert section.link == "https://danswerai.atlassian.net/browse/AS-2"

View File

@@ -81,6 +81,6 @@ RUN pip install --no-cache-dir --upgrade \
-r /tmp/dev-requirements.txt
COPY ./tests/integration /app/tests/integration
ENV PYTHONPATH /app
ENV PYTHONPATH=/app
CMD ["pytest", "-s", "/app/tests/integration"]

View File

@@ -0,0 +1,136 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from jira.resources import Issue
from pytest_mock import MockFixture
from danswer.connectors.danswer_jira.connector import fetch_jira_issues_batch
@pytest.fixture
def mock_jira_client() -> MagicMock:
return MagicMock()
@pytest.fixture
def mock_issue_small() -> MagicMock:
issue = MagicMock()
issue.key = "SMALL-1"
issue.fields.description = "Small description"
issue.fields.comment.comments = [
MagicMock(body="Small comment 1"),
MagicMock(body="Small comment 2"),
]
issue.fields.creator.displayName = "John Doe"
issue.fields.creator.emailAddress = "john@example.com"
issue.fields.summary = "Small Issue"
issue.fields.updated = "2023-01-01T00:00:00+0000"
issue.fields.labels = []
return issue
@pytest.fixture
def mock_issue_large() -> MagicMock:
# This will be larger than 100KB
issue = MagicMock()
issue.key = "LARGE-1"
issue.fields.description = "a" * 99_000
issue.fields.comment.comments = [
MagicMock(body="Large comment " * 1000),
MagicMock(body="Another large comment " * 1000),
]
issue.fields.creator.displayName = "Jane Doe"
issue.fields.creator.emailAddress = "jane@example.com"
issue.fields.summary = "Large Issue"
issue.fields.updated = "2023-01-02T00:00:00+0000"
issue.fields.labels = []
return issue
@pytest.fixture
def patched_type() -> Callable[[Any], type]:
def _patched_type(obj: Any) -> type:
if isinstance(obj, MagicMock):
return Issue
return type(obj)
return _patched_type
@pytest.fixture
def mock_jira_api_version() -> Generator[Any, Any, Any]:
with patch("danswer.connectors.danswer_jira.connector.JIRA_API_VERSION", "2"):
yield
@pytest.fixture
def patched_environment(
patched_type: type,
mock_jira_api_version: MockFixture,
) -> Generator[Any, Any, Any]:
with patch("danswer.connectors.danswer_jira.connector.type", patched_type):
yield
def test_fetch_jira_issues_batch_small_ticket(
mock_jira_client: MagicMock,
mock_issue_small: MagicMock,
patched_environment: MockFixture,
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 1
assert len(docs) == 1
assert docs[0].id.endswith("/SMALL-1")
assert "Small description" in docs[0].sections[0].text
assert "Small comment 1" in docs[0].sections[0].text
assert "Small comment 2" in docs[0].sections[0].text
def test_fetch_jira_issues_batch_large_ticket(
mock_jira_client: MagicMock,
mock_issue_large: MagicMock,
patched_environment: MockFixture,
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 1
assert len(docs) == 0 # The large ticket should be skipped
def test_fetch_jira_issues_batch_mixed_tickets(
mock_jira_client: MagicMock,
mock_issue_small: MagicMock,
mock_issue_large: MagicMock,
patched_environment: MockFixture,
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 2
assert len(docs) == 1 # Only the small ticket should be included
assert docs[0].id.endswith("/SMALL-1")
@patch("danswer.connectors.danswer_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50)
def test_fetch_jira_issues_batch_custom_size_limit(
mock_jira_client: MagicMock,
mock_issue_small: MagicMock,
mock_issue_large: MagicMock,
patched_environment: MockFixture,
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs, count = fetch_jira_issues_batch("project = TEST", 0, mock_jira_client)
assert count == 2
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit

View File

@@ -0,0 +1,18 @@
from typing import Any
import pytest
from danswer.indexing.indexing_heartbeat import Heartbeat
class MockHeartbeat(Heartbeat):
def __init__(self) -> None:
self.call_count = 0
def heartbeat(self, metadata: Any = None) -> None:
self.call_count += 1
@pytest.fixture
def mock_heartbeat() -> MockHeartbeat:
return MockHeartbeat()

View File

@@ -1,11 +1,24 @@
import pytest
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import DefaultIndexingEmbedder
from tests.unit.danswer.indexing.conftest import MockHeartbeat
def test_chunk_document() -> None:
@pytest.fixture
def embedder() -> DefaultIndexingEmbedder:
return DefaultIndexingEmbedder(
model_name="intfloat/e5-base-v2",
normalize=True,
query_prefix=None,
passage_prefix=None,
)
def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
short_section_1 = "This is a short section."
long_section = (
"This is a long section that should be split into multiple chunks. " * 100
@@ -30,18 +43,11 @@ def test_chunk_document() -> None:
],
)
embedder = DefaultIndexingEmbedder(
model_name="intfloat/e5-base-v2",
normalize=True,
query_prefix=None,
passage_prefix=None,
)
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
)
chunks = chunker.chunk(document)
chunks = chunker.chunk([document])
assert len(chunks) == 5
assert short_section_1 in chunks[0].content
@@ -49,3 +55,29 @@ def test_chunk_document() -> None:
assert short_section_4 in chunks[-1].content
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic
def test_chunker_heartbeat(
embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
) -> None:
document = Document(
id="test_doc",
source=DocumentSource.WEB,
semantic_identifier="Test Document",
metadata={"tags": ["tag1", "tag2"]},
doc_updated_at=None,
sections=[
Section(text="This is a short section.", link="link1"),
],
)
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
heartbeat=mock_heartbeat,
)
chunks = chunker.chunk([document])
assert mock_heartbeat.call_count == 1
assert len(chunks) > 0

View File

@@ -0,0 +1,90 @@
from collections.abc import Generator
from unittest.mock import Mock
from unittest.mock import patch
import pytest
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
@pytest.fixture
def mock_embedding_model() -> Generator[Mock, None, None]:
with patch("danswer.indexing.embedder.EmbeddingModel") as mock:
yield mock
def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
# Setup
embedder = DefaultIndexingEmbedder(
model_name="test-model",
normalize=True,
query_prefix=None,
passage_prefix=None,
provider_type=EmbeddingProvider.OPENAI,
)
# Mock the encode method of the embedding model
mock_embedding_model.return_value.encode.side_effect = [
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], # Main chunk embeddings
[[7.0, 8.0, 9.0]], # Title embedding
]
# Create test input
source_doc = Document(
id="test_doc",
source=DocumentSource.WEB,
semantic_identifier="Test Document",
metadata={"tags": ["tag1", "tag2"]},
doc_updated_at=None,
sections=[
Section(text="This is a short section.", link="link1"),
],
)
chunks: list[DocAwareChunk] = [
DocAwareChunk(
chunk_id=1,
blurb="This is a short section.",
content="Test chunk",
source_links={0: "link1"},
section_continuation=False,
source_document=source_doc,
title_prefix="Title: ",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_reference_ids=[],
)
]
# Execute
result: list[IndexChunk] = embedder.embed_chunks(chunks)
# Assert
assert len(result) == 1
assert isinstance(result[0], IndexChunk)
assert result[0].content == "Test chunk"
assert result[0].embeddings == ChunkEmbedding(
full_embedding=[1.0, 2.0, 3.0],
mini_chunk_embeddings=[],
)
assert result[0].title_embedding == [7.0, 8.0, 9.0]
# Verify the embedding model was called correctly
mock_embedding_model.return_value.encode.assert_any_call(
texts=["Title: Test chunk"],
text_type=EmbedTextType.PASSAGE,
large_chunks_present=False,
)
# title only embedding call
mock_embedding_model.return_value.encode.assert_any_call(
["Test Document"],
text_type=EmbedTextType.PASSAGE,
)

View File

@@ -0,0 +1,80 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from danswer.db.index_attempt import IndexAttempt
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
@pytest.fixture
def mock_db_session() -> MagicMock:
return MagicMock(spec=Session)
@pytest.fixture
def mock_index_attempt() -> MagicMock:
return MagicMock(spec=IndexAttempt)
def test_indexing_heartbeat(
mock_db_session: MagicMock, mock_index_attempt: MagicMock
) -> None:
with patch(
"danswer.indexing.indexing_heartbeat.get_index_attempt"
) as mock_get_index_attempt:
mock_get_index_attempt.return_value = mock_index_attempt
heartbeat = IndexingHeartbeat(
index_attempt_id=1, db_session=mock_db_session, freq=5
)
# Test that heartbeat doesn't update before freq is reached
for _ in range(4):
heartbeat.heartbeat()
mock_db_session.commit.assert_not_called()
# Test that heartbeat updates when freq is reached
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once_with(
db_session=mock_db_session, index_attempt_id=1
)
assert mock_index_attempt.time_updated is not None
mock_db_session.commit.assert_called_once()
# Reset mock calls
mock_db_session.reset_mock()
mock_get_index_attempt.reset_mock()
# Test that heartbeat updates again after freq more calls
for _ in range(5):
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
with patch(
"danswer.indexing.indexing_heartbeat.get_index_attempt"
) as mock_get_index_attempt, patch(
"danswer.indexing.indexing_heartbeat.logger"
) as mock_logger:
mock_get_index_attempt.return_value = None
heartbeat = IndexingHeartbeat(
index_attempt_id=1, db_session=mock_db_session, freq=1
)
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once_with(
db_session=mock_db_session, index_attempt_id=1
)
mock_logger.error.assert_called_once_with(
"Index attempt not found, this should not happen!"
)
mock_db_session.commit.assert_not_called()

View File

@@ -25,10 +25,10 @@ COPY . .
RUN npm ci
# needed to get the `standalone` dir we expect later
ENV NEXT_PRIVATE_STANDALONE true
ENV NEXT_PRIVATE_STANDALONE=true
# Disable automatic telemetry collection
ENV NEXT_TELEMETRY_DISABLED 1
ENV NEXT_TELEMETRY_DISABLED=1
# Environment variables must be present at build time
# https://github.com/vercel/next.js/discussions/14030
@@ -77,7 +77,7 @@ RUN rm -rf /usr/local/lib/node_modules
# ENV NODE_ENV production
# Disable automatic telemetry collection
ENV NEXT_TELEMETRY_DISABLED 1
ENV NEXT_TELEMETRY_DISABLED=1
# Don't run production as root
RUN addgroup --system --gid 1001 nodejs

View File

@@ -202,7 +202,7 @@ export const PromptLibraryTable = ({
))}
</div>
</div>
<div className="mx-auto">
<div className="mx-auto overflow-x-auto">
<Table>
<TableHead>
<TableRow>
@@ -220,7 +220,16 @@ export const PromptLibraryTable = ({
.map((item) => (
<TableRow key={item.id}>
<TableCell>{item.prompt}</TableCell>
<TableCell>{item.content}</TableCell>
<TableCell
className="
max-w-xs
overflow-hidden
text-ellipsis
break-words
"
>
{item.content}
</TableCell>
<TableCell>{item.active ? "Active" : "Inactive"}</TableCell>
<TableCell>
<button

View File

@@ -162,6 +162,9 @@ export function ChatPage({
user,
availableAssistants
);
const finalAssistants = user
? orderAssistantsForUser(visibleAssistants, user)
: visibleAssistants;
const existingChatSessionAssistantId = selectedChatSession?.persona_id;
const [selectedAssistant, setSelectedAssistant] = useState<
@@ -216,7 +219,7 @@ export function ChatPage({
const liveAssistant =
alternativeAssistant ||
selectedAssistant ||
visibleAssistants[0] ||
finalAssistants[0] ||
availableAssistants[0];
useEffect(() => {
@@ -686,7 +689,7 @@ export function ChatPage({
useEffect(() => {
if (messageHistory.length === 0 && chatSessionIdRef.current === null) {
setSelectedAssistant(
visibleAssistants.find((persona) => persona.id === defaultAssistantId)
finalAssistants.find((persona) => persona.id === defaultAssistantId)
);
}
}, [defaultAssistantId]);
@@ -2390,10 +2393,7 @@ export function ChatPage({
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
// assistant stuff
assistantOptions={orderAssistantsForUser(
visibleAssistants,
user
)}
assistantOptions={finalAssistants}
selectedAssistant={liveAssistant}
setSelectedAssistant={onAssistantChange}
setAlternativeAssistant={setAlternativeAssistant}

View File

@@ -188,14 +188,6 @@ export async function fetchChatData(searchParams: {
!hasAnyConnectors &&
(!user || user.role === "admin");
const shouldDisplaySourcesIncompleteModal =
hasAnyConnectors &&
!shouldShowWelcomeModal &&
!ccPairs.some(
(ccPair) => ccPair.has_successful_run && ccPair.docs_indexed > 0
) &&
(!user || user.role == "admin");
// if no connectors are setup, only show personas that are pure
// passthrough and don't do any retrieval
if (!hasAnyConnectors) {