Compare commits

..

52 Commits

Author SHA1 Message Date
pablodanswer
ca418fdcf2 order seeding 2024-11-27 12:16:42 -08:00
pablodanswer
4a1230f028 proper no assistant typing + no assistant modal 2024-11-27 09:55:17 -08:00
pablodanswer
28e2b78b2e Fix search dropdown (#3269)
* validate dropdown

* validate

* update organization

* move to utils
2024-11-27 16:10:07 +00:00
Emerson Gomes
0553062ac6 Adds icons for Google Gemini models and custom model icons for L… (#3218)
* Add description for Google Gemini models and custom model icons for LiteLLM (OpenAI) proxied models

* Adds Vertex AI aliases for Claude

---------

Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
2024-11-26 10:13:21 -08:00
hagen-danswer
284e375ba3 Merge pull request #3257 from danswer-ai/minor-perm-sync
Improved logging for confluence doc sync and robust user creation
2024-11-26 09:59:38 -08:00
hagen-danswer
1f2f7d0ac2 Improved logging for confluence doc sync and robust user creation 2024-11-26 08:51:15 -08:00
pablodanswer
2ecc28b57d remove unused stripe promise (#3248) 2024-11-26 01:50:39 +00:00
rkuo-danswer
77cf9b3539 improve messaging and UI around cleanup of leftover index attempts (#3247)
* improve messaging and UI around cleanup of leftover index attempts

* add tag on init
2024-11-25 22:27:14 +00:00
Weves
076ce2ebd0 Saml fix 2024-11-25 09:12:43 -08:00
pablodanswer
b625ee32a7 File handling cleanup (#3240)
* fix google sites connector

* minior cleanup

* rm comments
2024-11-25 04:06:47 +00:00
Richard Kuo (Danswer)
c32b93fcc3 increase indexing worker concurrency to 3 2024-11-24 18:11:58 -08:00
pablodanswer
1c8476072e Assistant cleanup (#3236)
* minor cleanup

* ensure users don't modify built-in attributes of assistants

* update sidebar

* k

* update update flow + assistant creation
2024-11-25 00:13:34 +00:00
Chris Weaver
7573416ca1 Fix API keys for MIT users (#3237) 2024-11-24 16:55:19 -08:00
Yuhong Sun
86d8666481 Add Test Case 2024-11-24 15:42:14 -08:00
Yuhong Sun
8abcde91d4 Fix Test (#3242) 2024-11-24 14:31:28 -08:00
Yuhong Sun
3466451d51 Fix Prompt for Non Function Calling LLMs (#3241) 2024-11-24 14:16:57 -08:00
Yuhong Sun
413891f143 Token Level Log (#3238) 2024-11-23 18:41:50 -08:00
Yuhong Sun
7a0a4d4b79 Remove Deprecated Endpoints (#3235) 2024-11-23 14:44:23 -08:00
Yuhong Sun
a3439605a5 Remove Dead Code (#3234) 2024-11-23 14:31:59 -08:00
pablodanswer
694e79f5e1 minor enforcement of CSV length for internal processing (#3109) 2024-11-23 21:05:30 +00:00
pablodanswer
5dfafc8612 minor calendar cleanup (#3219) 2024-11-23 21:01:05 +00:00
Yuhong Sun
62a4aa10db Refactor Search (#3233) 2024-11-23 13:42:54 -08:00
Yuhong Sun
a357cdc4c9 Remove Dead Code (#3232) 2024-11-23 13:21:27 -08:00
Yuhong Sun
84615abfdd Seeding (#3231) 2024-11-23 13:12:42 -08:00
pablodanswer
8ae6b1960b Bugfix/usage report (#3075)
* fix pagination

* update side

* fixed query history

* minor update

* minor update

* typing
2024-11-23 20:11:39 +00:00
James Jordan
d9b87bbbc2 Fixed 400 error when author of ticket is no longer an active user in a Zendesk account. (#3168) 2024-11-23 12:15:38 -08:00
Sanju Lokuhitige
a0065b01af Update CONTRIBUTING.md (#3112)
fix Formatting and Linting hyperlink
2024-11-23 12:13:23 -08:00
pablodanswer
c5306148a3 Ensure daterange not consistently re rendered (#3229)
* ensure daterange not consistently re rendered

* minor clean up
2024-11-23 19:35:00 +00:00
hagen-danswer
1e17934de4 Merge pull request #3214 from danswer-ai/fix-slack-ui
cleaned up new slack bot creation
2024-11-23 10:53:47 -08:00
pablodanswer
93add96ccc Various Nits (#3228) 2024-11-23 10:53:24 -08:00
rkuo-danswer
3a466a4b08 add minimal retries to confluence probe (#3222)
* add minimal retries to confluence probe

* name variable correctly
2024-11-23 17:11:15 +00:00
hagen-danswer
85cbd9caed Increased slim doc batch size for confluence connector (#3221) 2024-11-23 00:42:15 +00:00
pablodanswer
9dc23bf3e7 revert to previous doc select logic (#3217)
* revert to previous doc select logic

* k
2024-11-22 23:26:53 +00:00
hagen-danswer
e32809f7ca moved it outside 2024-11-22 14:59:58 -08:00
hagen-danswer
3e58f9f8ab fixed ugly stuff 2024-11-22 14:39:55 -08:00
pablodanswer
2381c8d498 Refresh all assistants on assistant refresh (#3216)
* k

* k
2024-11-22 22:38:23 +00:00
hagen-danswer
c6dadb24dc cleaned up new slack bot creation 2024-11-22 11:53:51 -08:00
hagen-danswer
5dc07d4178 Each section is now cleaned before being chunked (#3210)
* Each section is now cleaned before being chunked

* k

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-11-22 19:06:19 +00:00
Chris Weaver
129c8f8faf Add start/end date ability for query history as CSV endpoint (#3211) 2024-11-22 18:29:13 +00:00
pablodanswer
67bfcabbc5 llm provider causing re render in effect (#3205)
* llm provider causing re render in effect

* clean

* unused

* k
2024-11-22 16:53:24 +00:00
rkuo-danswer
9819aa977a implement double check pattern for error conditions (#3201)
* Move unfenced check to check_for_indexing. implement a double check pattern for all indexing error checks

* improved commenting

* exclusions
2024-11-22 04:21:02 +00:00
hagen-danswer
8d5b8a4028 Merge pull request #3202 from danswer-ai/toggled_chat_default
Update default sidebar toggle
2024-11-21 19:53:05 -08:00
pablodanswer
682319d2e9 Bugfix/curator interface (#3198)
* mystery solved

* update config

* update

* update

* update user role

* remove values
2024-11-22 02:33:09 +00:00
hagen-danswer
fe1400aa36 replace deprecated confluence group api endpoint (#3197)
* replace deprecated confluence group api endpoint

* reworked it

* properly escaped the user query

* less passing around is_cloud

* done
2024-11-22 01:51:29 +00:00
pablodanswer
e3573b2bc1 add comment 2024-11-21 17:11:11 -08:00
pablodanswer
35b5c44cc7 update default sidebar toggle 2024-11-21 17:09:56 -08:00
rkuo-danswer
5eddc89b5a merge indexing and heartbeat callbacks (and associated lock reacquisi… (#3178)
* merge indexing and heartbeat callbacks (and associated lock reacquisition). no db updates

* review fixes
2024-11-21 23:48:58 +00:00
hagen-danswer
9a492ceb6d admins cant be set as curator on backend (#3194)
* set-curator

* updated error
2024-11-21 23:33:29 +00:00
rkuo-danswer
3c54ae9de9 Bugfix/redis wait (#3169)
* rename to payload

* log redis info replication on primary worker startup

* fix mypy

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-21 23:11:00 +00:00
pablodanswer
13f08f3ebb Horizontal scrollbar (#3195)
* clean horizontal scrollbar

* account for additional edge case
2024-11-21 22:08:21 +00:00
pablodanswer
bd9f15854f provider fix (#3187)
* clean horizontal scrollbar

* provider fix

* ensure proper migration

* k

* update migration

* Revert "clean horizontal scrollbar"

This reverts commit fa592a1b7a.
2024-11-21 22:08:16 +00:00
pablodanswer
366aa2a8ea quick fix (#3200) 2024-11-21 14:07:55 -08:00
182 changed files with 1769 additions and 1510 deletions

View File

@@ -32,7 +32,7 @@ To contribute to this project, please follow the
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
### Getting Help 🙋

View File

@@ -0,0 +1,59 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_
revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)
results = conn.execute(providers_to_update).fetchall()
for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []
combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)
def downgrade() -> None:
pass

View File

@@ -0,0 +1,45 @@
"""remove default bot
Revision ID: 6d562f86c78b
Revises: 177de57c21c9
Create Date: 2024-11-22 11:51:29.331336
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6d562f86c78b"
down_revision = "177de57c21c9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
DELETE FROM slack_bot
WHERE name = 'Default Bot'
AND bot_token = ''
AND app_token = ''
AND NOT EXISTS (
SELECT 1 FROM slack_channel_config
WHERE slack_channel_config.slack_bot_id = slack_bot.id
)
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)

View File

@@ -9,8 +9,8 @@ from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -49,7 +49,7 @@ from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
@@ -80,8 +80,8 @@ from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
@@ -609,7 +609,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
@@ -618,13 +618,21 @@ async def optional_user_(
async def optional_user(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
user = await versioned_fetch_user(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
return user
async def double_check_user(
@@ -910,8 +918,8 @@ def get_oauth_router(
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
@@ -921,7 +929,7 @@ def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")

View File

@@ -1,5 +1,6 @@
import multiprocessing
from typing import Any
from typing import cast
from celery import bootsteps # type: ignore
from celery import Celery
@@ -14,14 +15,16 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -95,6 +98,15 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, info.get("role"))
connected_slaves: int = info.get("connected_slaves", 0)
logger.info(
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
@@ -153,13 +165,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
continue
failure_reason = (
f"Orphaned index attempt found on startup: "
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_failed(attempt.id, db_session, failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@worker_ready.connect

View File

@@ -4,7 +4,6 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -17,6 +16,7 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.redis.redis_connector import RedisConnector
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -78,7 +78,7 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
@@ -111,10 +111,15 @@ def extract_ids_from_runnable_connector(
for doc_batch in doc_batch_generator:
if callback:
if callback.should_stop():
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
return all_connector_doc_ids

View File

@@ -19,7 +19,7 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
from danswer.redis.redis_pool import get_redis_client
@@ -118,7 +118,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletionFenceData(
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)

View File

@@ -29,7 +29,7 @@ from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
logger = setup_logger()
@@ -66,9 +66,9 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if last_ext_group_sync is None:
return True
source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
# If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
if not source_sync_period:
return True

View File

@@ -3,6 +3,7 @@ from datetime import timezone
from http import HTTPStatus
from time import sleep
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
@@ -16,7 +17,6 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -33,6 +33,8 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
@@ -42,9 +44,11 @@ from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
@@ -57,7 +61,7 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class RunIndexingCallback(RunIndexingCallbackInterface):
class IndexingCallback(IndexingHeartbeatInterface):
def __init__(
self,
stop_key: str,
@@ -73,6 +77,7 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
@@ -80,15 +85,17 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
return True
return False
def progress(self, amount: int) -> None:
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"RunIndexingCallback - lock.reacquire exceptioned. "
f"IndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
@@ -97,6 +104,54 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
self.redis_client.incrby(self.generator_progress_key, amount)
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@@ -107,7 +162,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -117,6 +172,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -135,13 +191,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)
# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
@@ -198,6 +259,29 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -207,6 +291,11 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -311,10 +400,11 @@ def try_creating_indexing_task(
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -365,6 +455,8 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -385,13 +477,16 @@ def try_creating_indexing_task(
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(None)
task_logger.exception(
f"Unexpected exception: "
f"try_creating_indexing_task - Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -409,7 +504,7 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -417,7 +512,7 @@ def connector_indexing_proxy_task(
client = SimpleJobClient()
job = client.submit(
connector_indexing_task,
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -428,7 +523,7 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -436,7 +531,7 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -460,7 +555,7 @@ def connector_indexing_proxy_task(
if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
f"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -472,7 +567,7 @@ def connector_indexing_proxy_task(
break
task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -480,6 +575,38 @@ def connector_indexing_proxy_task(
return
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
return result
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -534,6 +661,7 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)
@@ -541,18 +669,18 @@ def connector_indexing_task(
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)
while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
payload = redis_connector_index.payload
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
@@ -575,7 +703,7 @@ def connector_indexing_task(
)
break
lock = r.lock(
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
@@ -584,7 +712,7 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None
@@ -619,7 +747,7 @@ def connector_indexing_task(
)
# define a callback class
callback = RunIndexingCallback(
callback = IndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,

View File

@@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -277,7 +277,7 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
callback = RunIndexingCallback(
callback = IndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,

View File

@@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import cast
import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -47,13 +46,10 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -649,20 +645,26 @@ def monitor_ccpair_indexing_taskset(
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # completion signal not set ... check for errors
# If we get here, and then the task both sets the completion signal and finishes,
# we will incorrectly abort the task. We must check result state, then check
# get_completion again to avoid the race condition.
if result_state in READY_STATES:
if status_int is None: # inner signal not set ... possible error
result_state = result.state
if (
result_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
@@ -697,37 +699,6 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
if r.exists(fence_key):
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
@@ -779,25 +750,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"permissions_sync={n_permissions_sync} "
)
# Fail any index attempts in the DB that don't have fences
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)

View File

@@ -1,7 +1,5 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -31,7 +29,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
@@ -42,19 +40,6 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -106,7 +91,7 @@ def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -138,13 +123,7 @@ def _run_indexing(
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
callback=callback,
)
indexing_pipeline = build_indexing_pipeline(
@@ -157,6 +136,7 @@ def _run_indexing(
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -228,7 +208,9 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError("Connector stop signal detected")
raise RuntimeError(
"_run_indexing: Connector stop signal detected"
)
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -289,7 +271,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress(len(doc_batch))
callback.progress("_run_indexing", len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -419,7 +401,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
try:
if is_ee:

View File

@@ -7,10 +7,10 @@ from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.context.search.models import InferenceSection
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -6,10 +6,10 @@ from typing import Any
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.context.search.models import SearchResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType

View File

@@ -23,6 +23,16 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.retrieval.search_runner import inference_sections_from_ids
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.context.search.utils import dedupe_documents
from danswer.context.search.utils import drop_llm_indices
from danswer.context.search.utils import relevant_sections_to_indices
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -56,16 +66,6 @@ from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line

View File

@@ -1,115 +0,0 @@
from typing_extensions import TypedDict # noreorder
from pydantic import BaseModel
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
from danswer.prompts.chat_tools import TOOL_TEMPLATE
from danswer.prompts.chat_tools import USER_INPUT
class ToolInfo(TypedDict):
name: str
description: str
class DanswerChatModelOut(BaseModel):
model_raw: str
action: str
action_input: str
def call_tool(
model_actions: DanswerChatModelOut,
) -> str:
raise NotImplementedError("There are no additional tool integrations right now")
def form_user_prompt_text(
query: str,
tool_text: str | None,
hint_text: str | None,
user_input_prompt: str = USER_INPUT,
tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(user_input=query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "\nHint: " + hint_text
return user_prompt.strip()
def form_tool_section_text(
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled and tools:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
if tools:
num_tools = len(tools)
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
prefix = "Must be one of " if num_tools > 1 else "Must be "
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def form_tool_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
# If multi-line query, it likely confuses the model more than helps
if "\n" not in query:
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
else:
optional_reminder = ""
if not ignore_hint and hint_text:
hint_text_spaced = f"\nHint: {hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def form_tool_less_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
hint = f"Hint: {hint_text}" if hint_text else ""
return tool_followup_prompt.format(
context_str=tool_output, user_query=query, hint_text=hint
).strip()

View File

@@ -234,7 +234,7 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
try:
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
if not env_value:
@@ -422,6 +422,9 @@ LOG_ALL_MODEL_INTERACTIONS = (
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
LOG_INDIVIDUAL_MODEL_TOKENS = (
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (

View File

@@ -1,9 +1,9 @@
import os
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
@@ -17,9 +17,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# For selecting a different LLM question-answering prompt format
# Valid values: default, cot, weak
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(
@@ -27,8 +24,6 @@ DOC_TIME_DECAY = float(
)
BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
# Currently only applies to search flow not chat

View File

@@ -3,15 +3,13 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from atlassian import Confluence # type: ignore
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
@@ -53,6 +51,8 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 1000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -114,25 +114,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self._confluence_client = build_confluence_client(
credentials_json=credentials,
credentials=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)
client_without_retries = Confluence(
api_version="cloud" if self.is_cloud else "latest",
url=self.wiki_base.rstrip("/"),
username=credentials["confluence_username"] if self.is_cloud else None,
password=credentials["confluence_access_token"] if self.is_cloud else None,
token=credentials["confluence_access_token"] if not self.is_cloud else None,
)
spaces = client_without_retries.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(
f"No spaces found at {self.wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
return None
def _get_comment_string_for_page_id(self, page_id: str) -> str:
@@ -280,6 +265,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
@@ -303,6 +289,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
doc_metadata_list.append(
SlimDocument(

View File

@@ -232,7 +232,6 @@ class OnyxConfluence(Confluence):
def paginated_cql_user_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
@@ -241,10 +240,28 @@ class OnyxConfluence(Confluence):
It's a seperate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
cql = "type=user"
url = "rest/api/search/user" if self.cloud else "rest/api/search"
expand_string = f"&expand={expand}" if expand else ""
yield from self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
url += f"?cql={cql}{expand_string}"
yield from self._paginate_url(url, limit)
def paginated_groups_by_user_retrieval(
self,
user: dict[str, Any],
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_value = user["accountId"] if self.cloud else user["userKey"]
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
def paginated_groups_retrieval(
self,
@@ -264,6 +281,58 @@ class OnyxConfluence(Confluence):
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch the members of a group.
THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name.
E.g. neither "test/group" nor "test%2Fgroup" works for confluence.
"""
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
)

View File

@@ -269,20 +269,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def build_confluence_client(
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials_json["confluence_username"] if is_cloud else None,
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
)

View File

@@ -102,13 +102,21 @@ def _get_tickets(
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
# Skip fetching if author_id is invalid
if not author_id or author_id == "-1":
return None
try:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
except requests.exceptions.HTTPError:
# Handle any API errors gracefully
return None
def _article_to_document(

View File

@@ -8,13 +8,13 @@ from pydantic import field_validator
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.enums import SearchType
from danswer.db.models import Persona
from danswer.db.models import SearchSettings
from danswer.indexing.models import BaseChunk
from danswer.indexing.models import IndexingSetting
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.enums import RerankerProvider

View File

@@ -7,6 +7,22 @@ from sqlalchemy.orm import Session
from danswer.chat.models import SectionRelevancePiece
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import RetrievalMetricsContainer
from danswer.context.search.models import SearchQuery
from danswer.context.search.models import SearchRequest
from danswer.context.search.postprocessing.postprocessing import cleanup_chunks
from danswer.context.search.postprocessing.postprocessing import search_postprocessing
from danswer.context.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.context.search.retrieval.search_runner import retrieve_chunks
from danswer.context.search.utils import inference_section_from_chunks
from danswer.context.search.utils import relevant_sections_to_indices
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
@@ -16,22 +32,6 @@ from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.search.utils import relevant_sections_to_indices
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall

View File

@@ -9,19 +9,19 @@ from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import ChunkMetric
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import MAX_METRICS_CONTENT
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import SearchQuery
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import RerankingModel
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall

View File

@@ -1,8 +1,8 @@
from sqlalchemy.orm import Session
from danswer.access.access import get_acl_for_user
from danswer.context.search.models import IndexFilters
from danswer.db.models import User
from danswer.search.models import IndexFilters
def build_access_filters_for_user(user: User | None, session: Session) -> list[str]:

View File

@@ -9,21 +9,25 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import SearchQuery
from danswer.context.search.models import SearchRequest
from danswer.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from danswer.context.search.retrieval.search_runner import (
remove_stop_words_and_punctuation,
)
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import QueryAnalysisModel
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import BaseFilters
from danswer.search.models import IndexFilters
from danswer.search.models import RerankingDetails
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.secondary_llm_flows.source_filter import extract_source_filter
from danswer.secondary_llm_flows.time_filter import extract_time_filter
from danswer.utils.logger import setup_logger

View File

@@ -6,6 +6,16 @@ from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from danswer.context.search.models import ChunkMetric
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import MAX_METRICS_CONTENT
from danswer.context.search.models import RetrievalMetricsContainer
from danswer.context.search.models import SearchQuery
from danswer.context.search.postprocessing.postprocessing import cleanup_chunks
from danswer.context.search.utils import inference_section_from_chunks
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_multilingual_expansion
from danswer.document_index.interfaces import DocumentIndex
@@ -14,16 +24,6 @@ from danswer.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@@ -1,9 +1,9 @@
from typing import cast
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.context.search.models import SavedSearchSettings
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.search.models import SavedSearchSettings
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -2,12 +2,12 @@ from collections.abc import Sequence
from typing import TypeVar
from danswer.chat.models import SectionRelevancePiece
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SavedSearchDoc
from danswer.context.search.models import SavedSearchDocWithContent
from danswer.context.search.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SavedSearchDocWithContent
from danswer.search.models import SearchDoc
T = TypeVar(

View File

@@ -21,6 +21,7 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.context.search.models import SavedSearchDoc
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
@@ -31,7 +32,6 @@ from danswer.danswerbot.slack.icons import source_to_github_img_link
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
from danswer.search.models import SavedSearchDoc
from danswer.utils.text_processing import decode_escapes
from danswer.utils.text_processing import replace_whitespaces_w_space

View File

@@ -21,6 +21,10 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@@ -48,10 +52,6 @@ from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import DanswerLoggingAdapter

View File

@@ -27,6 +27,7 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.context.search.retrieval.search_runner import download_nltk_data
from danswer.danswerbot.slack.config import get_slack_channel_config_for_bot_and_channel
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
@@ -75,7 +76,6 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

View File

@@ -2,6 +2,7 @@ import uuid
from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
@@ -45,14 +46,16 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
]
def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
api_key = db_session.scalar(
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
async def fetch_user_for_api_key(
hashed_api_key: str, async_db_session: AsyncSession
) -> User | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
return await async_db_session.scalar(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
)
if api_key is None:
return None
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
def get_api_key_fake_email(

View File

@@ -18,6 +18,9 @@ from danswer.auth.schemas import UserRole
from danswer.chat.models import DocumentRelevance
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.context.search.models import RetrievalDocs
from danswer.context.search.models import SavedSearchDoc
from danswer.context.search.models import SearchDoc as ServerSearchDoc
from danswer.db.models import ChatMessage
from danswer.db.models import ChatMessage__SearchDoc
from danswer.db.models import ChatSession
@@ -31,9 +34,6 @@ from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc as ServerSearchDoc
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger

View File

@@ -5,6 +5,7 @@ class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
CANCELED = "canceled"
FAILED = "failed"
COMPLETED_WITH_ERRORS = "completed_with_errors"
@@ -12,6 +13,7 @@ class IndexingStatus(str, PyEnum):
terminal_states = {
IndexingStatus.SUCCESS,
IndexingStatus.COMPLETED_WITH_ERRORS,
IndexingStatus.CANCELED,
IndexingStatus.FAILED,
}
return self in terminal_states

View File

@@ -67,6 +67,13 @@ def create_index_attempt(
return new_attempt.id
def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt:
db_session.delete(index_attempt)
db_session.commit()
def mock_successful_index_attempt(
connector_credential_pair_id: int,
search_settings_id: int,
@@ -218,6 +225,28 @@ def mark_attempt_partially_succeeded(
raise
def mark_attempt_canceled(
index_attempt_id: int,
db_session: Session,
reason: str = "Unknown",
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
if not attempt.time_started:
attempt.time_started = datetime.now(timezone.utc)
attempt.status = IndexingStatus.CANCELED
attempt.error_msg = reason
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_failed(
index_attempt_id: int,
db_session: Session,

View File

@@ -57,7 +57,7 @@ from danswer.utils.special_types import JSON_ro
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.context.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from danswer.utils.headers import HeaderItemDict
@@ -1181,7 +1181,7 @@ class LLMProvider(Base):
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Models to actually disp;aly to users
# Models to actually display to users
# If nulled out, we assume in the application logic we should present all
display_model_names: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True

View File

@@ -20,6 +20,7 @@ from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
@@ -33,7 +34,6 @@ from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.utils.logger import setup_logger
@@ -390,6 +390,9 @@ def upsert_prompt(
return prompt
# NOTE: This operation cannot update persona configuration options that
# are core to the persona, such as its display priority and
# whether or not the assistant is a built-in / default assistant
def upsert_persona(
user: User | None,
name: str,
@@ -458,7 +461,7 @@ def upsert_persona(
validate_persona_tools(tools)
if persona:
if not builtin_persona and persona.builtin_persona:
if persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
@@ -474,7 +477,6 @@ def upsert_persona(
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.builtin_persona = builtin_persona
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
@@ -484,10 +486,8 @@ def upsert_persona(
persona.icon_shape = icon_shape
if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id
persona.display_priority = display_priority
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.is_default_persona = is_default_persona
persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided

View File

@@ -12,6 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.context.search.models import SavedSearchSettings
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
@@ -21,7 +22,6 @@ from danswer.db.models import SearchSettings
from danswer.indexing.models import IndexingSetting
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)

View File

@@ -5,6 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
@@ -15,7 +16,6 @@ from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.errors import EERequiredError
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,

View File

@@ -103,17 +103,6 @@ def list_users(
return db_session.scalars(stmt).unique().all()
def get_users_by_emails(
db_session: Session, emails: list[str]
) -> tuple[list[User], list[str]]:
# Use distinct to avoid duplicates
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
found_users_emails = [user.email for user in found_users]
missing_user_emails = [email for email in emails if email not in found_users_emails]
return found_users, missing_user_emails
def get_user_by_email(email: str, db_session: Session) -> User | None:
user = (
db_session.query(User)
@@ -128,7 +117,7 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
return db_session.query(User).filter(User.id == user_id).first() # type: ignore
def _generate_non_web_slack_user(email: str) -> User:
def _generate_slack_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
@@ -149,13 +138,29 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
db_session.commit()
return user
user = _generate_non_web_slack_user(email=email)
user = _generate_slack_user(email=email)
db_session.add(user)
db_session.commit()
return user
def _generate_non_web_permissioned_user(email: str) -> User:
def _get_users_by_emails(
db_session: Session, lower_emails: list[str]
) -> tuple[list[User], list[str]]:
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
# Extract found emails and convert to lowercase to avoid case sensitivity issues
found_users_emails = [user.email.lower() for user in found_users]
# Separate emails for users that were not found
missing_user_emails = [
email for email in lower_emails if email not in found_users_emails
]
return found_users, missing_user_emails
def _generate_ext_permissioned_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
@@ -169,12 +174,12 @@ def _generate_non_web_permissioned_user(email: str) -> User:
def batch_add_ext_perm_user_if_not_exists(
db_session: Session, emails: list[str]
) -> list[User]:
emails = [email.lower() for email in emails]
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
lower_emails = [email.lower() for email in emails]
found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)
new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_permissioned_user(email=email))
for email in missing_lower_emails:
new_users.append(_generate_ext_permissioned_user(email=email))
db_session.add_all(new_users)
db_session.commit()

View File

@@ -3,10 +3,10 @@ import uuid
from sqlalchemy.orm import Session
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.models import IndexChunk
from danswer.search.models import InferenceChunk
DEFAULT_BATCH_SIZE = 30

View File

@@ -4,9 +4,9 @@ from datetime import datetime
from typing import Any
from danswer.access.models import DocumentAccess
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from shared_configs.model_server_models import Embedding

View File

@@ -11,6 +11,8 @@ import httpx
from retry import retry
from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa.shared_utils.utils import get_vespa_http_client
from danswer.document_index.vespa.shared_utils.vespa_request_builders import (
@@ -44,8 +46,6 @@ from danswer.document_index.vespa_constants import SOURCE_LINKS
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TITLE
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@@ -22,6 +22,8 @@ from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
from danswer.configs.chat_configs import VESPA_SEARCHER_THREADS
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentInsertionRecord
from danswer.document_index.interfaces import UpdateRequest
@@ -68,8 +70,6 @@ from danswer.document_index.vespa_constants import VESPA_TIMEOUT
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.key_value_store.factory import get_kv_store
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

View File

@@ -3,6 +3,7 @@ from datetime import timedelta
from datetime import timezone
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.context.search.models import IndexFilters
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa_constants import ACCESS_CONTROL_LIST
from danswer.document_index.vespa_constants import CHUNK_ID
@@ -13,7 +14,6 @@ from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import METADATA_LIST
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.search.models import IndexFilters
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -10,10 +10,11 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_text
from danswer.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
@@ -125,7 +126,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
heartbeat: Heartbeat | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -134,7 +135,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.heartbeat = heartbeat
self.callback = callback
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -220,9 +221,20 @@ class Chunker:
mini_chunk_texts=self._get_mini_chunk_texts(text),
)
for section in document.sections:
section_text = section.text
for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
section_link_text = section.link or ""
# If there is no useful content, not even the title, just drop it
if not section_text and (not document.title or section_idx > 0):
# If a section is empty and the document has no title, we can just drop it. We return a list of
# DocAwareChunks where each one contains the necessary information needed down the line for indexing.
# There is no concern about dropping whole documents from this list, it should not cause any indexing failures.
logger.warning(
f"Skipping section {section.text} from document "
f"{document.semantic_identifier} due to empty text after cleaning "
f" with link {section_link_text}"
)
continue
section_token_count = len(self.tokenizer.tokenize(section_text))
@@ -238,31 +250,26 @@ class Chunker:
split_texts = self.chunk_splitter.split_text(section_text)
for i, split_text in enumerate(split_texts):
split_token_count = len(self.tokenizer.tokenize(split_text))
if STRICT_CHUNK_TOKEN_LIMIT:
split_token_count = len(self.tokenizer.tokenize(split_text))
if split_token_count > content_token_limit:
# Further split the oversized chunk
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
chunks.append(
_create_chunk(
text=small_chunk,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
else:
if (
STRICT_CHUNK_TOKEN_LIMIT
and
# Tokenizer only runs if STRICT_CHUNK_TOKEN_LIMIT is true
len(self.tokenizer.tokenize(split_text)) > content_token_limit
):
# If STRICT_CHUNK_TOKEN_LIMIT is true, manually check
# the token count of each split text to ensure it is
# not larger than the content_token_limit
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
chunks.append(
_create_chunk(
text=split_text,
text=small_chunk,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
else:
chunks.append(
_create_chunk(
@@ -354,11 +361,20 @@ class Chunker:
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
"""
Takes in a list of documents and chunks them into smaller chunks for indexing
while persisting the document metadata.
"""
final_chunks: list[DocAwareChunk] = []
for document in documents:
final_chunks.extend(self._handle_single_document(document))
if self.callback:
if self.callback.should_stop():
raise RuntimeError("Chunker.chunk: Stop signal detected")
if self.heartbeat:
self.heartbeat.heartbeat()
chunks = self._handle_single_document(document)
final_chunks.extend(chunks)
if self.callback:
self.callback.progress("Chunker.chunk", len(chunks))
return final_chunks

View File

@@ -2,7 +2,7 @@ from abc import ABC
from abc import abstractmethod
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -34,7 +34,7 @@ class IndexingEmbedder(ABC):
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
heartbeat: Heartbeat | None,
callback: IndexingHeartbeatInterface | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -60,7 +60,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
heartbeat=heartbeat,
callback=callback,
)
@abstractmethod
@@ -83,7 +83,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
heartbeat: Heartbeat | None = None,
callback: IndexingHeartbeatInterface | None = None,
):
super().__init__(
model_name,
@@ -95,7 +95,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url,
api_version,
deployment_name,
heartbeat,
callback,
)
@log_function_time()
@@ -201,7 +201,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
cls,
search_settings: SearchSettings,
callback: IndexingHeartbeatInterface | None = None,
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -213,5 +215,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
heartbeat=heartbeat,
callback=callback,
)

View File

@@ -1,41 +1,15 @@
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
from abc import ABC
from abc import abstractmethod
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")
@abstractmethod
def progress(self, tag: str, amount: int) -> None:
"""Send progress updates to the caller."""

View File

@@ -34,7 +34,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -414,6 +414,7 @@ def build_indexing_pipeline(
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
@@ -440,13 +441,8 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
)
return partial(

View File

@@ -233,6 +233,8 @@ class Answer:
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(

View File

@@ -58,8 +58,8 @@ class AnswerPromptBuilder:
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
raw_user_text: str,
single_message_history: str | None = None,
raw_user_text: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -89,11 +89,7 @@ class AnswerPromptBuilder:
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
self.raw_user_message = (
HumanMessage(content=raw_user_text)
if raw_user_text is not None
else user_message
)
self.raw_user_message = raw_user_text
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:

View File

@@ -3,6 +3,7 @@ from langchain.schema.messages import SystemMessage
from danswer.chat.models import LlmDoc
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
@@ -29,7 +30,6 @@ from danswer.prompts.token_counts import (
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -2,45 +2,15 @@ from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
def _build_weak_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
as an option to use with weaker LLMs such as small version, low float precision, quantized,
or distilled models. It only uses one context document and has very weak requirements of
output format.
"""
context_block = ""
if context_docs:
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content)
prompt_str = WEAK_LLM_PROMPT.format(
system_prompt=prompt.system_prompt,
context_block=context_block,
task_prompt=prompt.task_prompt,
user_query=question,
)
if prompt.datetime_aware:
prompt_str = add_date_time_to_prompt(prompt_str=prompt_str)
return HumanMessage(content=prompt_str)
def _build_strong_llm_quotes_prompt(
@@ -81,15 +51,9 @@ def build_quotes_user_message(
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
if QA_PROMPT_OVERRIDE == "weak"
else _build_strong_llm_quotes_prompt
)
query, _ = message_to_prompt_and_imgs(message)
return prompt_builder(
return _build_strong_llm_quotes_prompt(
question=query,
context_docs=context_docs,
history_str=history_str,

View File

@@ -10,6 +10,8 @@ from danswer.chat.models import (
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
@@ -17,8 +19,6 @@ from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger

View File

@@ -13,6 +13,9 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.utils.logger import setup_logger
logger = setup_logger()
class AnswerResponseHandler(abc.ABC):
@@ -48,6 +51,9 @@ class CitationResponseHandler(AnswerResponseHandler):
self.processed_text = ""
self.citations: list[CitationInfo] = []
# TODO remove this after citation issue is resolved
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
def handle_response_part(
self,
response_item: BaseMessage | None,

View File

@@ -12,9 +12,9 @@ from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.context.search.models import InferenceChunk
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import clean_up_code_blocks

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from pydantic import BaseModel
from danswer.chat.models import LlmDoc
from danswer.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunk
class DocumentIdOrderMapping(BaseModel):

View File

@@ -62,7 +62,7 @@ class ToolResponseHandler:
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
@@ -76,7 +76,7 @@ class ToolResponseHandler:
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
@@ -95,7 +95,7 @@ class ToolResponseHandler:
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
llm=llm,
)
if available_tools_and_args

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.app_configs import LOG_INDIVIDUAL_MODEL_TOKENS
from danswer.utils.logger import setup_logger
@@ -117,10 +118,19 @@ class LLM(abc.ABC):
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._stream_implementation(
messages = self._stream_implementation(
prompt, tools, tool_choice, structured_response_format
)
tokens = []
for message in messages:
if LOG_INDIVIDUAL_MODEL_TOKENS:
tokens.append(message.content)
yield message
if LOG_INDIVIDUAL_MODEL_TOKENS and tokens:
logger.debug(f"Model Tokens: {tokens}")
@abc.abstractmethod
def _stream_implementation(
self,

View File

@@ -136,9 +136,11 @@ def translate_history_to_basemessages(
return history_basemessages, history_token_counts
def _process_csv_file(file: InMemoryChatFile) -> str:
# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
csv_preview = df.head().to_string()
csv_preview = df.head().to_string(max_cols=max_columns)
file_name_section = (
f"CSV FILE NAME: {file.filename}\n"

View File

@@ -1,4 +1,3 @@
import re
import threading
import time
from collections.abc import Callable
@@ -16,7 +15,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -50,28 +49,6 @@ def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
_INITIAL_FILTER = re.compile(
"["
"\U0000FFF0-\U0000FFFF" # Specials
"\U0001F000-\U0001F9FF" # Emoticons
"\U00002000-\U0000206F" # General Punctuation
"\U00002190-\U000021FF" # Arrows
"\U00002700-\U000027BF" # Dingbats
"]+",
flags=re.UNICODE,
)
def clean_openai_text(text: str) -> str:
# Remove specific Unicode ranges that might cause issues
cleaned = _INITIAL_FILTER.sub("", text)
# Remove any control characters except for newline and tab
cleaned = "".join(ch for ch in cleaned if ch >= " " or ch in "\n\t")
return cleaned
def build_model_server_url(
model_server_host: str,
model_server_port: int,
@@ -99,7 +76,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
callback: IndexingHeartbeatInterface | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None:
@@ -116,7 +93,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.heartbeat = heartbeat
self.callback = callback
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -160,6 +137,10 @@ class EmbeddingModel:
embeddings: list[Embedding] = []
for idx, text_batch in enumerate(text_batches, start=1):
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
@@ -179,8 +160,8 @@ class EmbeddingModel:
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.heartbeat:
self.heartbeat.heartbeat()
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
return embeddings
def encode(
@@ -211,11 +192,6 @@ class EmbeddingModel:
for text in texts
]
if self.provider_type == EmbeddingProvider.OPENAI:
# If the provider is openai, we need to clean the text
# as a temporary workaround for the openai API
texts = [clean_openai_text(text) for text in texts]
batch_size = (
api_embedding_batch_size
if self.provider_type

View File

@@ -7,7 +7,7 @@ from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider

View File

@@ -18,6 +18,11 @@ from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import RetrievalMetricsContainer
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.context.search.utils import dedupe_documents
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -42,11 +47,6 @@ from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail

View File

@@ -9,12 +9,12 @@ from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import ChunkContext
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
class QueryRephrase(BaseModel):
@@ -36,10 +36,6 @@ class PromptConfig(BaseModel):
datetime_aware: bool = True
class DocumentSetConfig(BaseModel):
id: int
class ToolConfig(BaseModel):
id: int

View File

@@ -118,18 +118,6 @@ You should always get right to the point, and never use extraneous language.
"""
# For weak LLM which only takes one chunk and cannot output json
# Also not requiring quotes as it tends to not work
WEAK_LLM_PROMPT = f"""
{{system_prompt}}
{{context_block}}
{{task_prompt}}
{QUESTION_PAT.upper()}
{{user_query}}
""".strip()
# This is only for visualization for the users to specify their own prompts
# The actual flow does not work like this
PARAMATERIZED_PROMPT = f"""

View File

@@ -7,12 +7,12 @@ from langchain_core.messages import BaseMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.constants import DocumentSource
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Prompt
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger

View File

@@ -17,7 +17,7 @@ from danswer.db.document import construct_document_select_for_connector_credenti
from danswer.db.models import Document as DbDocument
class RedisConnectorDeletionFenceData(BaseModel):
class RedisConnectorDeletePayload(BaseModel):
num_tasks: int | None
submitted: datetime
@@ -54,20 +54,18 @@ class RedisConnectorDelete:
return False
@property
def payload(self) -> RedisConnectorDeletionFenceData | None:
def payload(self) -> RedisConnectorDeletePayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_str)
)
payload = RedisConnectorDeletePayload.model_validate_json(cast(str, fence_str))
return payload
def set_fence(self, payload: RedisConnectorDeletionFenceData | None) -> None:
def set_fence(self, payload: RedisConnectorDeletePayload | None) -> None:
if not payload:
self.redis.delete(self.fence_key)
return

View File

@@ -1,12 +1,12 @@
import re
from danswer.chat.models import SectionRelevancePiece
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.agentic_evaluation import AGENTIC_SEARCH_SYSTEM_PROMPT
from danswer.prompts.agentic_evaluation import AGENTIC_SEARCH_USER_PROMPT
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -1,9 +1,9 @@
# NOTE No longer used. This needs to be revisited later.
import re
from collections.abc import Iterator
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -46,7 +46,7 @@ def extract_answerability_bool(model_raw: str) -> bool:
def get_query_answerability(
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
user_query: str, skip_check: bool = False
) -> tuple[str, bool]:
if skip_check:
return "Query Answerability Evaluation feature is turned off", True
@@ -67,7 +67,7 @@ def get_query_answerability(
def stream_query_answerability(
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
user_query: str, skip_check: bool = False
) -> Iterator[str]:
if skip_check:
yield get_json_line(

View File

@@ -5,6 +5,7 @@ from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
@@ -14,7 +15,6 @@ from danswer.db.models import Tool as ToolDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(
@@ -81,6 +81,7 @@ def load_personas_from_yaml(
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)

View File

@@ -5,6 +5,10 @@ from fastapi import Query
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.context.search.models import IndexFilters
from danswer.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
@@ -12,8 +16,6 @@ from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import IndexFilters
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo
from danswer.server.documents.models import DocumentInfo

View File

@@ -176,6 +176,9 @@ def create_persona(
)
# NOTE: This endpoint cannot update persona configuration options that
# are core to the persona, such as its display priority and
# whether or not the assistant is a built-in / default assistant
@basic_router.patch("/{persona_id}")
def update_persona(
persona_id: int,

View File

@@ -4,10 +4,10 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.models import Persona
from danswer.db.models import PersonaCategory
from danswer.db.models import StarterMessage
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.prompt.models import PromptSnapshot
from danswer.server.features.tool.models import ToolSnapshot

View File

@@ -6,11 +6,11 @@ from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.factory import get_default_llms
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.server.danswer_api.ingestion import api_key_dep
from danswer.utils.logger import setup_logger

View File

@@ -10,6 +10,7 @@ from pydantic import model_validator
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from danswer.configs.constants import AuthType
from danswer.context.search.models import SavedSearchSettings
from danswer.danswerbot.slack.config import VALID_SLACK_FILTERS
from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
@@ -17,7 +18,6 @@ from danswer.db.models import SlackBot as SlackAppModel
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig as SlackChannelConfigModel
from danswer.db.models import User
from danswer.search.models import SavedSearchSettings
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot

View File

@@ -7,6 +7,8 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.context.search.models import SavedSearchSettings
from danswer.context.search.models import SearchSettingsCreationRequest
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.engine import get_session
@@ -25,8 +27,6 @@ from danswer.file_processing.unstructured import delete_unstructured_api_key
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import update_unstructured_api_key
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
from danswer.search.models import SearchSettingsCreationRequest
from danswer.server.manage.embedding.models import SearchSettingsDeleteRequest
from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.models import IdReturn

View File

@@ -618,7 +618,6 @@ def update_user_assistant_list(
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.chosen_assistants = request.chosen_assistants
set_no_auth_user_preferences(store, no_auth_user.preferences)

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.engine import get_session
from danswer.db.models import Persona
from danswer.db.models import User
@@ -19,7 +20,6 @@ from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.db.tools import get_tool_by_name
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.constants import MessageType
from danswer.context.search.models import RetrievalDetails
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
@@ -20,7 +21,6 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import User
from danswer.search.models import RetrievalDetails
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.search.search_tool import SearchTool

View File

@@ -9,15 +9,15 @@ from danswer.chat.models import RetrievalDocs
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import ChunkContext
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.models import SearchDoc
from danswer.context.search.models import Tag
from danswer.db.enums import ChatSessionSharedStatus
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.models import BaseFilters
from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails
from danswer.search.models import SearchDoc
from danswer.search.models import Tag
from danswer.tools.models import ToolCallFinalResult
@@ -29,10 +29,6 @@ class TagResponse(BaseModel):
tags: list[SourceTag]
class SimpleQueryRequest(BaseModel):
query: str
class UpdateChatSessionThreadRequest(BaseModel):
# If not specified, use Danswer default persona
chat_session_id: UUID
@@ -217,6 +213,7 @@ class ChatSessionDetailResponse(BaseModel):
current_alternate_model: str | None
# This one is not used anymore
class QueryValidationResponse(BaseModel):
reasoning: str
answerable: bool

View File

@@ -13,6 +13,12 @@ from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import SearchDoc
from danswer.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
@@ -28,19 +34,11 @@ from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.one_shot_answer.answer_question import stream_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.search.models import IndexFilters
from danswer.search.models import SearchDoc
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.query_and_chat.models import AdminSearchRequest
from danswer.server.query_and_chat.models import AdminSearchResponse
from danswer.server.query_and_chat.models import ChatSessionDetails
from danswer.server.query_and_chat.models import ChatSessionsResponse
from danswer.server.query_and_chat.models import QueryValidationResponse
from danswer.server.query_and_chat.models import SearchSessionDetailResponse
from danswer.server.query_and_chat.models import SimpleQueryRequest
from danswer.server.query_and_chat.models import SourceTag
from danswer.server.query_and_chat.models import TagResponse
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
@@ -133,18 +131,6 @@ def get_tags(
return TagResponse(tags=server_tags)
@basic_router.post("/query-validation")
def query_validation(
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
) -> QueryValidationResponse:
# Note if weak model prompt is chosen, this check does not occur and will simply return that
# the query is valid, this is because weaker models cannot really handle this task well.
# Additionally, some weak model servers cannot handle concurrent inferences.
logger.notice(f"Validating query: {simple_query.query}")
reasoning, answerable = get_query_answerability(simple_query.query)
return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
@basic_router.get("/user-searches")
def get_user_search_sessions(
user: User | None = Depends(current_user),
@@ -245,21 +231,6 @@ def get_search_session(
return response
# NOTE No longer used, after search/chat redesign.
# No search responses are answered with a conversational generative AI response
@basic_router.post("/stream-query-validation")
def stream_query_validation(
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
) -> StreamingResponse:
# Note if weak model prompt is chosen, this check does not occur and will simply return that
# the query is valid, this is because weaker models cannot really handle this task well.
# Additionally, some weak model servers cannot handle concurrent inferences.
logger.notice(f"Validating query: {simple_query.query}")
return StreamingResponse(
stream_query_answerability(simple_query.query), media_type="application/json"
)
@basic_router.post("/stream-answer-with-quote")
def get_answer_with_quote(
query_request: DirectQARequest,

View File

@@ -2,7 +2,6 @@ import time
from sqlalchemy.orm import Session
from danswer.chat.load_yamls import load_chat_yamls
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import MANAGED_VESPA
from danswer.configs.constants import KV_REINDEX_KEY
@@ -10,6 +9,8 @@ from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.context.search.models import SavedSearchSettings
from danswer.context.search.retrieval.search_runner import download_nltk_data
from danswer.db.connector import check_connectors_exist
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
@@ -37,9 +38,8 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.seeding.load_docs import seed_initial_documents
from danswer.seeding.load_yamls import load_chat_yamls
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
@@ -254,13 +254,14 @@ def setup_postgres(db_session: Session) -> None:
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls(db_session)
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
logger.notice("Loading default Prompts and Personas")
load_chat_yamls(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)

View File

@@ -11,6 +11,9 @@ from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RetrievalDetails
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
from danswer.db.models import User
@@ -22,9 +25,6 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import Tool

View File

@@ -10,6 +10,7 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.context.search.models import SearchDoc
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
@@ -18,7 +19,6 @@ from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.search.models import SearchDoc
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse

View File

@@ -14,6 +14,14 @@ from danswer.chat.models import SectionRelevancePiece
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.llm.answering.llm_response_handler import LLMCall
@@ -27,14 +35,6 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.answering.prune_and_merge import prune_and_merge_sections
from danswer.llm.answering.prune_and_merge import prune_sections
from danswer.llm.interfaces import LLM
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.message import ToolCallSummary

View File

@@ -1,6 +1,6 @@
from danswer.chat.models import LlmDoc
from danswer.context.search.models import InferenceSection
from danswer.prompts.prompt_utils import clean_up_source
from danswer.search.models import InferenceSection
def llm_doc_to_dict(llm_doc: LlmDoc, doc_num: int) -> dict:

View File

@@ -1,5 +1,7 @@
from typing import cast
from langchain_core.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
@@ -58,9 +60,11 @@ def build_next_prompt_for_search_like_tool(
# For Quotes, the system prompt is included in the user prompt
prompt_builder.update_system_prompt(None)
human_message = HumanMessage(content=prompt_builder.raw_user_message)
prompt_builder.update_user_prompt(
build_quotes_user_message(
message=prompt_builder.raw_user_message,
message=human_message,
context_docs=final_context_documents,
history_str=prompt_builder.single_message_history or "",
prompt=prompt_config,

View File

@@ -126,6 +126,28 @@ def shared_precompare_cleanup(text: str) -> str:
return text
_INITIAL_FILTER = re.compile(
"["
"\U0000FFF0-\U0000FFFF" # Specials
"\U0001F000-\U0001F9FF" # Emoticons
"\U00002000-\U0000206F" # General Punctuation
"\U00002190-\U000021FF" # Arrows
"\U00002700-\U000027BF" # Dingbats
"]+",
flags=re.UNICODE,
)
def clean_text(text: str) -> str:
# Remove specific Unicode ranges that might cause issues
cleaned = _INITIAL_FILTER.sub("", text)
# Remove any control characters except for newline and tab
cleaned = "".join(ch for ch in cleaned if ch >= " " or ch in "\n\t")
return cleaned
def is_valid_email(text: str) -> bool:
"""Can use a library instead if more detailed checks are needed"""
regex = r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"

View File

@@ -2,15 +2,13 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.db.saml import get_saml_account
@@ -28,22 +26,18 @@ def verify_auth_setting() -> None:
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
# Check if the user has a session cookie from SAML
if AUTH_TYPE == AuthType.SAML:
saved_cookie = extract_hashed_cookie(request)
if saved_cookie:
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
user = saml_account.user if saml_account else None
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
return user

View File

@@ -8,6 +8,7 @@ from sqlalchemy import desc
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import UnaryExpression
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
@@ -20,22 +21,22 @@ def fetch_chat_sessions_eagerly_by_time(
end: datetime.datetime,
db_session: Session,
limit: int | None = 500,
initial_id: int | None = None,
initial_time: datetime.datetime | None = None,
) -> list[ChatSession]:
id_order = desc(ChatSession.id) # type: ignore
time_order = desc(ChatSession.time_created) # type: ignore
message_order = asc(ChatMessage.id) # type: ignore
time_order: UnaryExpression = desc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)
filters: list[ColumnElement | BinaryExpression] = [
ChatSession.time_created.between(start, end)
]
if initial_id:
filters.append(ChatSession.id < initial_id)
if initial_time:
filters.append(ChatSession.time_created > initial_time)
subquery = (
db_session.query(ChatSession.id, ChatSession.time_created)
.filter(*filters)
.order_by(id_order, time_order)
.order_by(ChatSession.id, time_order)
.distinct(ChatSession.id)
.limit(limit)
.subquery()
@@ -43,7 +44,7 @@ def fetch_chat_sessions_eagerly_by_time(
query = (
db_session.query(ChatSession)
.join(subquery, ChatSession.id == subquery.c.id) # type: ignore
.join(subquery, ChatSession.id == subquery.c.id)
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
.options(
joinedload(ChatSession.user),

View File

@@ -5,11 +5,12 @@ from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.db.models import SamlAccount
from danswer.db.models import User
def upsert_saml_account(
@@ -44,10 +45,14 @@ def upsert_saml_account(
return saml_acc.expires_at
def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
async def get_saml_account(
cookie: str, async_db_session: AsyncSession
) -> SamlAccount | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
stmt = (
select(SamlAccount)
.join(User, User.id == SamlAccount.user_id) # type: ignore
.options(selectinload(SamlAccount.user)) # Use selectinload for collections
.where(
and_(
SamlAccount.encrypted_cookie == cookie,
@@ -56,10 +61,12 @@ def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
)
)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
result = await async_db_session.execute(stmt)
return result.scalars().unique().one_or_none()
def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None:
async def expire_saml_account(
saml_account: SamlAccount, async_db_session: AsyncSession
) -> None:
saml_account.expires_at = func.now()
db_session.commit()
await async_db_session.commit()

Some files were not shown because too many files have changed in this diff Show More