Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 08:15:48 +00:00)

Compare commits: eval/split...v0.4.7 (67 commits)
Only the commit SHA1s survive in this mirror; author, date, and message columns were not captured.

| SHA1 |
|---|
| 96b582070b |
| 4a0a927a64 |
| ea9a9cb553 |
| 38af12ab97 |
| 1b3154188d |
| 1f321826ad |
| cbfbe4e5d8 |
| 3aa0e0124b |
| f2f60c9cc0 |
| 6c32821ad4 |
| d839595330 |
| e422f96dff |
| d28f460330 |
| 8e441d975d |
| 5c78af1f07 |
| e325e063ed |
| c81b45300b |
| 26a1e963d1 |
| 2a983263c7 |
| 2a37c95a5e |
| c277a74f82 |
| e4b31cd0d9 |
| a40d2a1e2e |
| c9fb99d719 |
| a4d71e08aa |
| 546bfbd24b |
| 27824d6cc6 |
| 9d5c4ad634 |
| 9b32003816 |
| 8bc4123ed7 |
| d58aaf7a59 |
| a0056a1b3c |
| d2584c773a |
| 807bef8ada |
| 5afddacbb2 |
| 4fb6a88f1e |
| 7057be6a88 |
| 91be8e7bfb |
| 9651ea828b |
| 6ee74bd0d1 |
| 48a0d29a5c |
| 6ff8e6c0ea |
| 2470c68506 |
| 866bc803b1 |
| 9c6084bd0d |
| a0b46c60c6 |
| 4029233df0 |
| 6c88c0156c |
| 33332d08f2 |
| 17005fb705 |
| 48a7fe80b1 |
| 1276732409 |
| f91b92a898 |
| 6222f533be |
| 1b49d17239 |
| 2f5f19642e |
| 6db4634871 |
| 5cfed45cef |
| 581ffde35a |
| 6313e6d91d |
| c09c94bf32 |
| 0e8ba111c8 |
| 2ba24b1734 |
| 44820b4909 |
| eb3e7610fc |
| 7fbbb174bb |
| 3854ca11af |
.github/pull_request_template.md (vendored, new file, 25 lines)

@@ -0,0 +1,25 @@
## Description
[Provide a brief description of the changes in this PR]

## How Has This Been Tested?
[Describe the tests you ran to verify your changes]

## Accepted Risk
[Any know risks or failure modes to point out to reviewers]

## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]

## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge
@@ -17,15 +17,11 @@ depends_on: None = None


 def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.add_column(
         "chat_session",
         sa.Column("current_alternate_model", sa.String(), nullable=True),
     )
-    # ### end Alembic commands ###


 def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.drop_column("chat_session", "current_alternate_model")
-    # ### end Alembic commands ###
@@ -0,0 +1,26 @@
"""add_indexing_start_to_connector

Revision ID: 08a1eda20fe1
Revises: 8a87bd6ec550
Create Date: 2024-07-23 11:12:39.462397

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "08a1eda20fe1"
down_revision = "8a87bd6ec550"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("connector", "indexing_start")
@@ -0,0 +1,70 @@
"""Add icon_color and icon_shape to Persona

Revision ID: 325975216eb3
Revises: 91ffac7e65b3
Create Date: 2024-07-24 21:29:31.784562

"""
import random
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select

# revision identifiers, used by Alembic.
revision = "325975216eb3"
down_revision = "91ffac7e65b3"
branch_labels: None = None
depends_on: None = None


colorOptions = [
    "#FF6FBF",
    "#6FB1FF",
    "#B76FFF",
    "#FFB56F",
    "#6FFF8D",
    "#FF6F6F",
    "#6FFFFF",
]


# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
def generate_random_shape() -> int:
    center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
    center_fill = random.choice(center_squares)
    remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
    random.shuffle(remaining_squares)
    for i in range(10 - bin(center_fill).count("1")):
        center_fill |= 1 << remaining_squares[i]
    return center_fill


def upgrade() -> None:
    op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
    op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
    op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))

    persona = table(
        "persona",
        column("id", sa.Integer),
        column("icon_color", sa.String),
        column("icon_shape", sa.Integer),
    )

    conn = op.get_bind()
    personas = conn.execute(select(persona.c.id))

    for persona_id in personas:
        random_color = random.choice(colorOptions)
        random_shape = generate_random_shape()
        conn.execute(
            persona.update()
            .where(persona.c.id == persona_id[0])
            .values(icon_color=random_color, icon_shape=random_shape)
        )


def downgrade() -> None:
    op.drop_column("persona", "icon_shape")
    op.drop_column("persona", "uploaded_image_id")
    op.drop_column("persona", "icon_color")
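The migration above stores each persona's icon as a single integer, and generate_random_shape() fills individual bits from 0 to 15. The sketch below is illustrative only (not part of the migration): it assumes each of the 16 low bits maps to one cell of a 4x4 grid, which is how the bit arithmetic above appears to be intended.

```python
# Illustrative sketch only: decode an icon_shape integer into a 4x4 grid,
# assuming bit i corresponds to cell i read row by row.
def decode_icon_shape(shape: int) -> list[list[bool]]:
    return [
        [bool(shape & (1 << (row * 4 + col))) for col in range(4)]
        for row in range(4)
    ]


if __name__ == "__main__":
    # 23013 is one of the default icon_shape values seen later in personas.yaml
    for row in decode_icon_shape(23013):
        print("".join("#" if cell else "." for cell in row))
```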
@@ -18,7 +18,6 @@ depends_on: None = None


 def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.add_column(
         "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
     )
@@ -29,10 +28,8 @@ def upgrade() -> None:
         ["alternate_assistant_id"],
         ["id"],
     )
-    # ### end Alembic commands ###


 def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
     op.drop_column("chat_message", "alternate_assistant_id")
@@ -0,0 +1,49 @@
"""Add display_model_names to llm_provider

Revision ID: 473a1a7ca408
Revises: 325975216eb3
Create Date: 2024-07-25 14:31:02.002917

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "473a1a7ca408"
down_revision = "325975216eb3"
branch_labels = None
depends_on = None

default_models_by_provider = {
    "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
    "bedrock": [
        "meta.llama3-1-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "mistral.mistral-large-2402-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
    ],
    "anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
}


def upgrade() -> None:
    op.add_column(
        "llm_provider",
        sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
    )

    connection = op.get_bind()
    for provider, models in default_models_by_provider.items():
        connection.execute(
            sa.text(
                "UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
            ),
            {"models": models, "provider": provider},
        )


def downgrade() -> None:
    op.drop_column("llm_provider", "display_model_names")
@@ -0,0 +1,72 @@
"""Add type to credentials

Revision ID: 4ea2c93919c1
Revises: 473a1a7ca408
Create Date: 2024-07-18 13:07:13.655895

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "4ea2c93919c1"
down_revision = "473a1a7ca408"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the new 'source' column to the 'credential' table
    op.add_column(
        "credential",
        sa.Column(
            "source",
            sa.String(length=100),  # Use String instead of Enum
            nullable=True,  # Initially allow NULL values
        ),
    )
    op.add_column(
        "credential",
        sa.Column(
            "name",
            sa.String(),
            nullable=True,
        ),
    )

    # Create a temporary table that maps each credential to a single connector source.
    # This is needed because a credential can be associated with multiple connectors,
    # but we want to assign a single source to each credential.
    # We use DISTINCT ON to ensure we only get one row per credential_id.
    op.execute(
        """
        CREATE TEMPORARY TABLE temp_connector_credential AS
        SELECT DISTINCT ON (cc.credential_id)
            cc.credential_id,
            c.source AS connector_source
        FROM connector_credential_pair cc
        JOIN connector c ON cc.connector_id = c.id
        """
    )

    # Update the 'source' column in the 'credential' table
    op.execute(
        """
        UPDATE credential cred
        SET source = COALESCE(
            (SELECT connector_source
             FROM temp_connector_credential temp
             WHERE cred.id = temp.credential_id),
            'NOT_APPLICABLE'
        )
        """
    )
    # If no exception was raised, alter the column
    op.alter_column("credential", "source", nullable=True)  # TODO modify
    # # ### end Alembic commands ###


def downgrade() -> None:
    op.drop_column("credential", "source")
    op.drop_column("credential", "name")
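The comments in the migration above explain the intent of the DISTINCT ON step: a credential may be paired with several connectors, but the backfill keeps exactly one (arbitrary) connector source per credential_id. The following sketch is illustrative only, with made-up data, showing the same "one row per key" idea in plain Python (it keeps the first pairing seen, whereas DISTINCT ON without an ORDER BY picks an unspecified one).

```python
# Illustrative sketch only: keep a single connector source per credential_id,
# mirroring the DISTINCT ON backfill above. Sample data is made up.
pairs = [
    (1, "github"),   # (credential_id, connector_source)
    (1, "gitlab"),   # a second connector for credential 1 is dropped
    (2, "slack"),
]

source_by_credential: dict[int, str] = {}
for credential_id, connector_source in pairs:
    source_by_credential.setdefault(credential_id, connector_source)

print(source_by_credential)  # {1: 'github', 2: 'slack'}
```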
@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control

Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558

"""
from alembic import op
import sqlalchemy as sa


revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "llm_provider__user_group",
        sa.Column("llm_provider_id", sa.Integer(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["llm_provider_id"],
            ["llm_provider.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"],
            ["user_group.id"],
        ),
        sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
    )
    op.add_column(
        "llm_provider",
        sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
    )


def downgrade() -> None:
    op.drop_table("llm_provider__user_group")
    op.drop_column("llm_provider", "is_public")
@@ -0,0 +1,94 @@
"""associate index attempts with ccpair

Revision ID: 8a87bd6ec550
Revises: 4ea2c93919c1
Create Date: 2024-07-22 15:15:52.558451

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "8a87bd6ec550"
down_revision = "4ea2c93919c1"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the new connector_credential_pair_id column
    op.add_column(
        "index_attempt",
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
    )

    # Create a foreign key constraint to the connector_credential_pair table
    op.create_foreign_key(
        "fk_index_attempt_connector_credential_pair_id",
        "index_attempt",
        "connector_credential_pair",
        ["connector_credential_pair_id"],
        ["id"],
    )

    # Populate the new connector_credential_pair_id column using existing connector_id and credential_id
    op.execute(
        """
        UPDATE index_attempt ia
        SET connector_credential_pair_id = ccp.id
        FROM connector_credential_pair ccp
        WHERE ia.connector_id = ccp.connector_id AND ia.credential_id = ccp.credential_id
        """
    )

    # Make the new connector_credential_pair_id column non-nullable
    op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)

    # Drop the old connector_id and credential_id columns
    op.drop_column("index_attempt", "connector_id")
    op.drop_column("index_attempt", "credential_id")

    # Update the index to use connector_credential_pair_id
    op.create_index(
        "ix_index_attempt_latest_for_connector_credential_pair",
        "index_attempt",
        ["connector_credential_pair_id", "time_created"],
    )


def downgrade() -> None:
    # Add back the old connector_id and credential_id columns
    op.add_column(
        "index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
    )
    op.add_column(
        "index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
    )

    # Populate the old connector_id and credential_id columns using the connector_credential_pair_id
    op.execute(
        """
        UPDATE index_attempt ia
        SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
        FROM connector_credential_pair ccp
        WHERE ia.connector_credential_pair_id = ccp.id
        """
    )

    # Make the old connector_id and credential_id columns non-nullable
    op.alter_column("index_attempt", "connector_id", nullable=False)
    op.alter_column("index_attempt", "credential_id", nullable=False)

    # Drop the new connector_credential_pair_id column
    op.drop_constraint(
        "fk_index_attempt_connector_credential_pair_id",
        "index_attempt",
        type_="foreignkey",
    )
    op.drop_column("index_attempt", "connector_credential_pair_id")

    op.create_index(
        "ix_index_attempt_latest_for_connector_credential_pair",
        "index_attempt",
        ["connector_id", "credential_id", "time_created"],
    )
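The upgrade above backfills connector_credential_pair_id with a Postgres UPDATE ... FROM join and then makes the column NOT NULL. The snippet below is a hypothetical pre-flight check, not part of the migration, that one could run beforehand to confirm no index_attempt row would be left without a matching pair (the connection string is a placeholder).

```python
# Hypothetical sanity check before the ALTER COLUMN ... NOT NULL step above.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://user:pass@localhost/danswer")  # placeholder DSN

with engine.connect() as conn:
    orphans = conn.execute(
        text(
            "SELECT count(*) FROM index_attempt "
            "WHERE connector_credential_pair_id IS NULL"
        )
    ).scalar_one()
    print(f"index_attempt rows without a connector_credential_pair: {orphans}")
```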
backend/alembic/versions/91ffac7e65b3_add_expiry_time.py (new file, 26 lines)

@@ -0,0 +1,26 @@
"""add expiry time

Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("user", "oidc_expiry")
@@ -16,7 +16,6 @@ depends_on: None = None


 def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.alter_column(
         "connector_credential_pair",
         "last_attempt_status",
@@ -29,11 +28,9 @@ def upgrade() -> None:
         ),
         nullable=True,
     )
-    # ### end Alembic commands ###


 def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.alter_column(
         "connector_credential_pair",
         "last_attempt_status",
@@ -46,4 +43,3 @@ def downgrade() -> None:
         ),
         nullable=False,
     )
-    # ### end Alembic commands ###
@@ -1,6 +1,8 @@
 import smtplib
 import uuid
 from collections.abc import AsyncGenerator
+from datetime import datetime
+from datetime import timezone
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from typing import Optional
@@ -50,8 +52,10 @@ from danswer.db.auth import get_default_admin_user_emails
 from danswer.db.auth import get_user_count
 from danswer.db.auth import get_user_db
 from danswer.db.engine import get_session
+from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import AccessToken
 from danswer.db.models import User
+from danswer.db.users import get_user_by_email
 from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
@@ -92,12 +96,18 @@ def user_needs_to_be_verified() -> bool:
     return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION


-def verify_email_in_whitelist(email: str) -> None:
+def verify_email_is_invited(email: str) -> None:
     whitelist = get_invited_users()
     if (whitelist and email not in whitelist) or not email:
         raise PermissionError("User not on allowed user whitelist")


+def verify_email_in_whitelist(email: str) -> None:
+    with Session(get_sqlalchemy_engine()) as db_session:
+        if not get_user_by_email(email, db_session):
+            verify_email_is_invited(email)
+
+
 def verify_email_domain(email: str) -> None:
     if VALID_EMAIL_DOMAINS:
         if email.count("@") != 1:
@@ -147,7 +157,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         safe: bool = False,
         request: Optional[Request] = None,
     ) -> models.UP:
-        verify_email_in_whitelist(user_create.email)
+        verify_email_is_invited(user_create.email)
         verify_email_domain(user_create.email)
         if hasattr(user_create, "role"):
             user_count = await get_user_count()
@@ -173,7 +183,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         verify_email_in_whitelist(account_email)
         verify_email_domain(account_email)

-        return await super().oauth_callback(  # type: ignore
+        user = await super().oauth_callback(  # type: ignore
             oauth_name=oauth_name,
             access_token=access_token,
             account_id=account_id,
@@ -185,6 +195,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
             is_verified_by_default=is_verified_by_default,
         )

+        # NOTE: google oauth expires after 1hr. We don't want to force the user to
+        # re-authenticate that frequently, so for now we'll just ignore this for
+        # google oauth users
+        if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
+            oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
+            await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
+        return user
+
     async def on_after_register(
         self, user: User, request: Optional[Request] = None
     ) -> None:
@@ -227,10 +245,12 @@ cookie_transport = CookieTransport(
 def get_database_strategy(
     access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
 ) -> DatabaseStrategy:
-    return DatabaseStrategy(
+    strategy = DatabaseStrategy(
         access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS  # type: ignore
     )

+    return strategy
+

 auth_backend = AuthenticationBackend(
     name="database",
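Per the hunk above, the renamed verify_email_is_invited() keeps the original invite-list check, while the new verify_email_in_whitelist() only applies it to addresses that do not already have an account. The sketch below is a simplified, self-contained illustration of that decision order, using stand-in lookups (hypothetical data) instead of the real database helpers.

```python
# Simplified sketch of the new check order; sets stand in for the real
# get_user_by_email / get_invited_users helpers.
EXISTING_USERS = {"alice@example.com"}
INVITED_EMAILS = {"bob@example.com"}


def verify_email_is_invited(email: str) -> None:
    if (INVITED_EMAILS and email not in INVITED_EMAILS) or not email:
        raise PermissionError("User not on allowed user whitelist")


def verify_email_in_whitelist(email: str) -> None:
    # Existing accounts (e.g. returning OAuth users) skip the invite check;
    # only brand-new emails must appear in the invite list.
    if email not in EXISTING_USERS:
        verify_email_is_invited(email)


for email in ("alice@example.com", "bob@example.com", "carol@example.com"):
    try:
        verify_email_in_whitelist(email)
        print(email, "-> allowed")
    except PermissionError:
        print(email, "-> rejected")
```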
@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
 from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
+from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
 from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
-from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
 from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
 )
@@ -80,7 +80,7 @@ def should_prune_cc_pair(
             return True
         return False

-    if PREVENT_SIMULTANEOUS_PRUNING:
+    if not ALLOW_SIMULTANEOUS_PRUNING:
         pruning_type_task_name = name_cc_prune_task()
         last_pruning_type_task = get_latest_task_by_type(
             pruning_type_task_name, db_session
@@ -89,11 +89,9 @@ def should_prune_cc_pair(
         if last_pruning_type_task and check_task_is_live_and_not_timed_out(
             last_pruning_type_task, db_session
         ):
-            logger.info("Another Connector is already pruning. Skipping.")
             return False

     if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
-        logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
         return False

     if not last_pruning_task.start_time:
@@ -20,7 +20,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pair
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import mark_attempt_failed
-from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
+from danswer.db.index_attempt import mark_attempt_in_progress
 from danswer.db.index_attempt import mark_attempt_succeeded
 from danswer.db.index_attempt import update_docs_indexed
 from danswer.db.models import IndexAttempt
@@ -49,19 +49,19 @@ def _get_document_generator(
     are the complete list of existing documents of the connector. If the task
     of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
     """
-    task = attempt.connector.input_type
+    task = attempt.connector_credential_pair.connector.input_type

     try:
         runnable_connector = instantiate_connector(
-            attempt.connector.source,
+            attempt.connector_credential_pair.connector.source,
             task,
-            attempt.connector.connector_specific_config,
-            attempt.credential,
+            attempt.connector_credential_pair.connector.connector_specific_config,
+            attempt.connector_credential_pair.credential,
             db_session,
         )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")
-        disable_connector(attempt.connector.id, db_session)
+        disable_connector(attempt.connector_credential_pair.connector.id, db_session)
         raise e

     if task == InputType.LOAD_STATE:
@@ -70,7 +70,10 @@ def _get_document_generator(

     elif task == InputType.POLL:
         assert isinstance(runnable_connector, PollConnector)
-        if attempt.connector_id is None or attempt.credential_id is None:
+        if (
+            attempt.connector_credential_pair.connector_id is None
+            or attempt.connector_credential_pair.connector_id is None
+        ):
             raise ValueError(
                 f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
                 f"can't fetch time range."
@@ -127,16 +130,21 @@ def _run_indexing(
         db_session=db_session,
     )

-    db_connector = index_attempt.connector
-    db_credential = index_attempt.credential
+    db_connector = index_attempt.connector_credential_pair.connector
+    db_credential = index_attempt.connector_credential_pair.credential

     last_successful_index_time = (
-        0.0
-        if index_attempt.from_beginning
-        else get_last_successful_attempt_time(
-            connector_id=db_connector.id,
-            credential_id=db_credential.id,
-            embedding_model=index_attempt.embedding_model,
-            db_session=db_session,
-        )
+        db_connector.indexing_start.timestamp()
+        if index_attempt.from_beginning and db_connector.indexing_start is not None
+        else (
+            0.0
+            if index_attempt.from_beginning
+            else get_last_successful_attempt_time(
+                connector_id=db_connector.id,
+                credential_id=db_credential.id,
+                embedding_model=index_attempt.embedding_model,
+                db_session=db_session,
+            )
+        )
     )
+
@@ -250,8 +258,8 @@ def _run_indexing(
         if is_primary:
             update_connector_credential_pair(
                 db_session=db_session,
-                connector_id=index_attempt.connector.id,
-                credential_id=index_attempt.credential.id,
+                connector_id=index_attempt.connector_credential_pair.connector.id,
+                credential_id=index_attempt.connector_credential_pair.credential.id,
                 net_docs=net_doc_change,
             )
         raise e
@@ -299,9 +307,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
        )

    # only commit once, to make sure this all happens in a single transaction
-    mark_attempt_in_progress__no_commit(attempt)
-    if attempt.embedding_model.status != IndexModelStatus.PRESENT:
-        db_session.commit()
+    mark_attempt_in_progress(attempt, db_session)

    return attempt

@@ -324,17 +330,17 @@ def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
        attempt = _prepare_index_attempt(db_session, index_attempt_id)

        logger.info(
-            f"Running indexing attempt for connector: '{attempt.connector.name}', "
-            f"with config: '{attempt.connector.connector_specific_config}', and "
-            f"with credentials: '{attempt.credential_id}'"
+            f"Running indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
+            f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
+            f"with credentials: '{attempt.connector_credential_pair.connector_id}'"
        )

        _run_indexing(db_session, attempt)

        logger.info(
-            f"Completed indexing attempt for connector: '{attempt.connector.name}', "
-            f"with config: '{attempt.connector.connector_specific_config}', and "
-            f"with credentials: '{attempt.credential_id}'"
+            f"Completed indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
+            f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
+            f"with credentials: '{attempt.connector_credential_pair.connector_id}'"
        )
    except Exception as e:
        logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
@@ -16,7 +16,9 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
 from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
+from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
 from danswer.db.connector import fetch_connectors
+from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.embedding_model import get_secondary_db_embedding_model
 from danswer.db.engine import get_db_current_time
@@ -24,7 +26,7 @@ from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import get_inprogress_index_attempts
-from danswer.db.index_attempt import get_last_attempt
+from danswer.db.index_attempt import get_last_attempt_for_cc_pair
 from danswer.db.index_attempt import get_not_started_index_attempts
 from danswer.db.index_attempt import mark_attempt_failed
 from danswer.db.models import Connector
@@ -33,7 +35,7 @@ from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.db.swap_index import check_index_swap
-from danswer.search.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -111,8 +113,8 @@ def _mark_run_failed(
     """Marks the `index_attempt` row as failed + updates the `
     connector_credential_pair` to reflect that the run failed"""
     logger.warning(
-        f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
-        f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
+        f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
+        f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
     )
     mark_attempt_failed(
         index_attempt=index_attempt,
@@ -131,7 +133,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
     3. There is not already an ongoing indexing attempt for this pair
     """
     with Session(get_sqlalchemy_engine()) as db_session:
-        ongoing: set[tuple[int | None, int | None, int]] = set()
+        ongoing: set[tuple[int | None, int]] = set()
         for attempt_id in existing_jobs:
             attempt = get_index_attempt(
                 db_session=db_session, index_attempt_id=attempt_id
@@ -144,8 +146,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
                 continue
             ongoing.add(
                 (
-                    attempt.connector_id,
-                    attempt.credential_id,
+                    attempt.connector_credential_pair_id,
                     attempt.embedding_model_id,
                 )
             )
@@ -155,31 +156,26 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
         if secondary_embedding_model is not None:
             embedding_models.append(secondary_embedding_model)

-        all_connectors = fetch_connectors(db_session)
-        for connector in all_connectors:
-            for association in connector.credentials:
-                for model in embedding_models:
-                    credential = association.credential
+        all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
+        for cc_pair in all_connector_credential_pairs:
+            for model in embedding_models:
+                # Check if there is an ongoing indexing attempt for this connector credential pair
+                if (cc_pair.id, model.id) in ongoing:
+                    continue

-                    # Check if there is an ongoing indexing attempt for this connector + credential pair
-                    if (connector.id, credential.id, model.id) in ongoing:
-                        continue
+                last_attempt = get_last_attempt_for_cc_pair(
+                    cc_pair.id, model.id, db_session
+                )
+                if not _should_create_new_indexing(
+                    connector=cc_pair.connector,
+                    last_index=last_attempt,
+                    model=model,
+                    secondary_index_building=len(embedding_models) > 1,
+                    db_session=db_session,
+                ):
+                    continue

-                    last_attempt = get_last_attempt(
-                        connector.id, credential.id, model.id, db_session
-                    )
-                    if not _should_create_new_indexing(
-                        connector=connector,
-                        last_index=last_attempt,
-                        model=model,
-                        secondary_index_building=len(embedding_models) > 1,
-                        db_session=db_session,
-                    ):
-                        continue
-
-                    create_index_attempt(
-                        connector.id, credential.id, model.id, db_session
-                    )
+                create_index_attempt(cc_pair.id, model.id, db_session)


 def cleanup_indexing_jobs(
@@ -271,6 +267,8 @@ def kickoff_indexing_jobs(
     # Don't include jobs waiting in the Dask queue that just haven't started running
     # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
     with Session(engine) as db_session:
+        # get_not_started_index_attempts orders its returned results from oldest to newest
+        # we must process attempts in a FIFO manner to prevent connector starvation
         new_indexing_attempts = [
             (attempt, attempt.embedding_model)
             for attempt in get_not_started_index_attempts(db_session)
@@ -288,7 +286,7 @@ def kickoff_indexing_jobs(
             if embedding_model is not None
             else False
         )
-        if attempt.connector is None:
+        if attempt.connector_credential_pair.connector is None:
            logger.warning(
                f"Skipping index attempt as Connector has been deleted: {attempt}"
            )
@@ -297,7 +295,7 @@ def kickoff_indexing_jobs(
                attempt, db_session, failure_reason="Connector is null"
            )
            continue
-        if attempt.credential is None:
+        if attempt.connector_credential_pair.credential is None:
            logger.warning(
                f"Skipping index attempt as Credential has been deleted: {attempt}"
            )
@@ -326,16 +324,20 @@ def kickoff_indexing_jobs(
            secondary_str = "(secondary index) " if use_secondary_index else ""
            logger.info(
                f"Kicked off {secondary_str}"
-                f"indexing attempt for connector: '{attempt.connector.name}', "
-                f"with config: '{attempt.connector.connector_specific_config}', and "
-                f"with credentials: '{attempt.credential_id}'"
+                f"indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
+                f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
+                f"with credentials: '{attempt.connector_credential_pair.credential_id}'"
            )
            existing_jobs_copy[attempt.id] = run

    return existing_jobs_copy


-def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
+def update_loop(
+    delay: int = 10,
+    num_workers: int = NUM_INDEXING_WORKERS,
+    num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
+) -> None:
     engine = get_sqlalchemy_engine()
     with Session(engine) as db_session:
         check_index_swap(db_session=db_session)
@@ -366,7 +368,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
             silence_logs=logging.ERROR,
         )
         cluster_secondary = LocalCluster(
-            n_workers=num_workers,
+            n_workers=num_secondary_workers,
             threads_per_worker=1,
             silence_logs=logging.ERROR,
         )
@@ -376,7 +378,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
         client_primary.register_worker_plugin(ResourceLogger())
     else:
         client_primary = SimpleJobClient(n_workers=num_workers)
-        client_secondary = SimpleJobClient(n_workers=num_workers)
+        client_secondary = SimpleJobClient(n_workers=num_secondary_workers)

     existing_jobs: dict[int, Future | SimpleJob] = {}

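The new comments in kickoff_indexing_jobs explain that not-started attempts are handled oldest first so no connector is starved. The sketch below is illustrative only (made-up data), showing the FIFO ordering property those comments rely on.

```python
# Illustrative sketch only: process pending attempts oldest-first (FIFO).
from datetime import datetime

pending = [
    {"id": 3, "time_created": datetime(2024, 7, 25, 12, 0)},
    {"id": 1, "time_created": datetime(2024, 7, 25, 9, 0)},
    {"id": 2, "time_created": datetime(2024, 7, 25, 10, 30)},
]

for attempt in sorted(pending, key=lambda a: a["time_created"]):
    print("would kick off attempt", attempt["id"])  # 1, then 2, then 3
```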
@@ -35,6 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
 def create_chat_chain(
     chat_session_id: int,
     db_session: Session,
+    prefetch_tool_calls: bool = True,
 ) -> tuple[ChatMessage, list[ChatMessage]]:
     """Build the linear chain of messages without including the root message"""
     mainline_messages: list[ChatMessage] = []
@@ -43,6 +44,7 @@ def create_chat_chain(
         user_id=None,
         db_session=db_session,
         skip_permission_check=True,
+        prefetch_tool_calls=prefetch_tool_calls,
     )
     id_to_msg = {msg.id: msg for msg in all_chat_messages}

@@ -88,6 +88,8 @@ def load_personas_from_yaml(
             llm_relevance_filter=persona.get("llm_relevance_filter"),
             starter_messages=persona.get("starter_messages"),
             llm_filter_extraction=persona.get("llm_filter_extraction"),
+            icon_shape=persona.get("icon_shape"),
+            icon_color=persona.get("icon_color"),
             llm_model_provider_override=None,
             llm_model_version_override=None,
             recency_bias=RecencyBiasSetting(persona["recency_bias"]),
@@ -37,7 +37,8 @@ personas:
     #  - "Engineer Onboarding"
     #  - "Benefits"
     document_sets: []
-
+    icon_shape: 23013
+    icon_color: "#6FB1FF"

   - id: 1
     name: "GPT"
@@ -50,6 +51,8 @@ personas:
     llm_filter_extraction: true
     recency_bias: "auto"
     document_sets: []
+    icon_shape: 50910
+    icon_color: "#FF6F6F"


   - id: 2
@@ -63,3 +66,6 @@ personas:
     llm_filter_extraction: true
     recency_bias: "auto"
     document_sets: []
+    icon_shape: 45519
+    icon_color: "#6FFF8D"
+
@@ -51,7 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_llms_for_persona
 from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.interfaces import LLMConfig
-from danswer.llm.utils import get_default_llm_tokenizer
+from danswer.natural_language_processing.utils import get_default_llm_tokenizer
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
@@ -187,37 +187,46 @@ def _handle_internet_search_tool_response_summary(
     )


-def _check_should_force_search(
-    new_msg_req: CreateChatMessageRequest,
-) -> ForceUseTool | None:
-    # If files are already provided, don't run the search tool
+def _get_force_search_settings(
+    new_msg_req: CreateChatMessageRequest, tools: list[Tool]
+) -> ForceUseTool:
+    internet_search_available = any(
+        isinstance(tool, InternetSearchTool) for tool in tools
+    )
+    search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
+
+    if not internet_search_available and not search_tool_available:
+        # Does not matter much which tool is set here as force is false and neither tool is available
+        return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
+
+    tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
+    # Currently, the internet search tool does not support query override
+    args = (
+        {"query": new_msg_req.query_override}
+        if new_msg_req.query_override and tool_name == SearchTool._NAME
+        else None
+    )
+
     if new_msg_req.file_descriptors:
-        return None
+        # If user has uploaded files they're using, don't run any of the search tools
+        return ForceUseTool(force_use=False, tool_name=tool_name)

-    if (
-        new_msg_req.query_override
-        or (
+    should_force_search = any(
+        [
             new_msg_req.retrieval_options
-            and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
-        )
-        or new_msg_req.search_doc_ids
-        or DISABLE_LLM_CHOOSE_SEARCH
-    ):
-        args = (
-            {"query": new_msg_req.query_override}
-            if new_msg_req.query_override
-            else None
-        )
-        # if we are using selected docs, just put something here so the Tool doesn't need
-        # to build its own args via an LLM call
-        if new_msg_req.search_doc_ids:
-            args = {"query": new_msg_req.message}
+            and new_msg_req.retrieval_options.run_search
+            == OptionalSearchSetting.ALWAYS,
+            new_msg_req.search_doc_ids,
+            DISABLE_LLM_CHOOSE_SEARCH,
+        ]
+    )

-        return ForceUseTool(
-            tool_name=SearchTool._NAME,
-            args=args,
-        )
-    return None
+    if should_force_search:
+        # If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
+        args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
+        return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
+
+    return ForceUseTool(force_use=False, tool_name=tool_name, args=args)


 ChatPacket = (
@@ -253,7 +262,6 @@ def stream_chat_message_objects(
     2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
     3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
     4. [always] Details on the final AI response message that is created
-
     """
     try:
         user_id = user.id if user is not None else None
@@ -274,7 +282,10 @@ def stream_chat_message_objects(
         # use alternate persona if alternative assistant id is passed in
         if alternate_assistant_id is not None:
             persona = get_persona_by_id(
-                alternate_assistant_id, user=user, db_session=db_session
+                alternate_assistant_id,
+                user=user,
+                db_session=db_session,
+                is_for_edit=False,
             )
         else:
             persona = chat_session.persona
@@ -361,6 +372,14 @@ def stream_chat_message_objects(
                 "when the last message is not a user message."
             )

+        # Disable Query Rephrasing for the first message
+        # This leads to a better first response since the LLM rephrasing the question
+        # leads to worst search quality
+        if not history_msgs:
+            new_msg_req.query_override = (
+                new_msg_req.query_override or new_msg_req.message
+            )
+
         # load all files needed for this chat chain in memory
         files = load_all_chat_files(
             history_msgs, new_msg_req.file_descriptors, db_session
@@ -576,11 +595,7 @@ def stream_chat_message_objects(
                 PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
             ],
             tools=tools,
-            force_use_tool=(
-                _check_should_force_search(new_msg_req)
-                if search_tool and len(tools) == 1
-                else None
-            ),
+            force_use_tool=_get_force_search_settings(new_msg_req, tools),
         )

         reference_db_search_docs = None
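The new _get_force_search_settings above decides whether to force a search tool from a few request properties: uploaded files win, otherwise a forced search happens when retrieval is set to always run, documents are pre-selected, or LLM tool choice is disabled. The sketch below is a simplified, hypothetical version of that decision with stand-in types; it is not the real implementation.

```python
# Hypothetical, simplified sketch of the force-search decision above;
# FakeRequest and the flag are stand-ins, not the real Danswer models.
from dataclasses import dataclass


@dataclass
class FakeRequest:
    query_override: str | None = None
    search_doc_ids: list[int] | None = None
    run_search_always: bool = False
    file_descriptors: list[str] | None = None


DISABLE_LLM_CHOOSE_SEARCH = False


def should_force_search(req: FakeRequest) -> bool:
    if req.file_descriptors:
        return False  # user-uploaded files take precedence over search tools
    return any([req.run_search_always, req.search_doc_ids, DISABLE_LLM_CHOOSE_SEARCH])


print(should_force_search(FakeRequest(search_doc_ids=[1, 2])))       # True
print(should_force_search(FakeRequest(file_descriptors=["a.pdf"])))  # False
print(should_force_search(FakeRequest()))                            # False
```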
@@ -214,8 +214,8 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (

 DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # Once a day

-PREVENT_SIMULTANEOUS_PRUNING = (
-    os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
+ALLOW_SIMULTANEOUS_PRUNING = (
+    os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
 )

 # This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
@@ -248,6 +248,9 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
 # fairly large amount of memory in order to increase substantially, since
 # each worker loads the embedding models into memory.
 NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
+NUM_SECONDARY_INDEXING_WORKERS = int(
+    os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
+)
 # More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
 ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
 # Finer grained chunking for more detail retention
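The new settings follow the same parsing pattern used throughout this config file: boolean flags are true only for the literal string "true" (case-insensitive), and integer settings fall back through `or` when the variable is unset or set to an empty string. A short demonstration of that behaviour, with example values:

```python
# Small demonstration of the parsing pattern above (values are examples).
import os

os.environ["ALLOW_SIMULTANEOUS_PRUNING"] = "True"
os.environ["NUM_SECONDARY_INDEXING_WORKERS"] = ""  # set but empty

ALLOW_SIMULTANEOUS_PRUNING = (
    os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
NUM_SECONDARY_INDEXING_WORKERS = int(
    os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)

print(ALLOW_SIMULTANEOUS_PRUNING)      # True  ("True".lower() == "true")
print(NUM_INDEXING_WORKERS)            # 1     (unset -> falls back to 1)
print(NUM_SECONDARY_INDEXING_WORKERS)  # 1     (empty string is falsy -> falls back)
```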
@@ -44,7 +44,6 @@ QUERY_EVENT_ID = "query_event_id"
 LLM_CHUNKS = "llm_chunks"

 # For chunking/processing chunks
-MAX_CHUNK_TITLE_LEN = 1000
 RETURN_SEPARATOR = "\n\r\n"
 SECTION_SEPARATOR = "\n\n"
 # For combining attributes, doesn't have to be unique/perfect to work
@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
     return " ".join(texts)


+def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
+    if hasattr(jira_issue.fields, field):
+        return getattr(jira_issue.fields, field)
+
+    try:
+        return jira_issue.raw["fields"][field]
+    except Exception:
+        return None
+
+
 def _get_comment_strs(
     jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
 ) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
             continue

         comments = _get_comment_strs(jira, comment_email_blacklist)
-        semantic_rep = f"{jira.fields.description}\n" + "\n".join(
-            [f"Comment: {comment}" for comment in comments]
+        semantic_rep = (
+            f"{jira.fields.description}\n"
+            if jira.fields.description
+            else "" + "\n".join([f"Comment: {comment}" for comment in comments])
         )

         page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
             pass

         metadata_dict = {}
-        if jira.fields.priority:
-            metadata_dict["priority"] = jira.fields.priority.name
-        if jira.fields.status:
-            metadata_dict["status"] = jira.fields.status.name
-        if jira.fields.resolution:
-            metadata_dict["resolution"] = jira.fields.resolution.name
-        if jira.fields.labels:
-            metadata_dict["label"] = jira.fields.labels
+        priority = best_effort_get_field_from_issue(jira, "priority")
+        if priority:
+            metadata_dict["priority"] = priority.name
+        status = best_effort_get_field_from_issue(jira, "status")
+        if status:
+            metadata_dict["status"] = status.name
+        resolution = best_effort_get_field_from_issue(jira, "resolution")
+        if resolution:
+            metadata_dict["resolution"] = resolution.name
+        labels = best_effort_get_field_from_issue(jira, "labels")
+        if labels:
+            metadata_dict["label"] = labels

         doc_batch.append(
             Document(
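The new best_effort_get_field_from_issue first tries the typed `fields` attribute and only then falls back to the raw JSON payload, returning None if the field is missing from both. The sketch below illustrates that fallback order with hypothetical stub objects standing in for a real jira Issue.

```python
# Hypothetical stubs standing in for a jira.Issue, to show the fallback order.
from types import SimpleNamespace
from typing import Any


def best_effort_get_field(issue: Any, field: str) -> Any:
    if hasattr(issue.fields, field):
        return getattr(issue.fields, field)
    try:
        return issue.raw["fields"][field]
    except Exception:
        return None


issue = SimpleNamespace(
    fields=SimpleNamespace(priority=SimpleNamespace(name="High")),
    raw={"fields": {"labels": ["infra", "bug"]}},
)

print(best_effort_get_field(issue, "priority").name)  # "High" (typed attribute)
print(best_effort_get_field(issue, "labels"))         # ["infra", "bug"] (raw fallback)
print(best_effort_get_field(issue, "resolution"))     # None (missing everywhere)
```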
@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
         self.permissions: DiscoursePerms | None = None
         self.active_categories: set | None = None

-    @rate_limit_builder(max_calls=100, period=60)
+    @rate_limit_builder(max_calls=50, period=60)
     def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
         if not self.permissions:
             raise ConnectorMissingCredentialError("Discourse")
@@ -11,6 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
 from sqlalchemy.orm import Session

 from danswer.configs.app_configs import WEB_DOMAIN
+from danswer.configs.constants import DocumentSource
 from danswer.connectors.gmail.constants import CRED_KEY
 from danswer.connectors.gmail.constants import (
     DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
@@ -146,6 +147,7 @@ def build_service_account_creds(
         credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email

     return CredentialBase(
+        source=DocumentSource.GMAIL,
         credential_json=credential_dict,
         admin_public=True,
     )
@@ -11,6 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
 from sqlalchemy.orm import Session

 from danswer.configs.app_configs import WEB_DOMAIN
+from danswer.configs.constants import DocumentSource
 from danswer.connectors.google_drive.constants import CRED_KEY
 from danswer.connectors.google_drive.constants import (
     DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
@@ -118,6 +119,7 @@ def update_credential_access_tokens(


 def build_service_account_creds(
+    source: DocumentSource,
     delegated_user_email: str | None = None,
 ) -> CredentialBase:
     service_account_key = get_service_account_key()
@@ -131,6 +133,7 @@ def build_service_account_creds(
     return CredentialBase(
         credential_json=credential_dict,
         admin_public=True,
+        source=DocumentSource.GOOGLE_DRIVE,
     )

@@ -86,7 +86,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
             categories: The categories to include in the index.
             pages: The pages to include in the index.
             recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
-            connector_name: The name of the connector.
             language_code: The language code of the wiki.
             batch_size: The batch size for loading documents.

@@ -104,7 +103,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
         categories: list[str],
         pages: list[str],
         recurse_depth: int,
-        connector_name: str,
         language_code: str = "en",
         batch_size: int = INDEX_BATCH_SIZE,
     ) -> None:
@@ -118,10 +116,8 @@ class MediaWikiConnector(LoadConnector, PollConnector):
         self.batch_size = batch_size

         # short names can only have ascii letters and digits
-        self.connector_name = connector_name
-        connector_name = "".join(ch for ch in connector_name if ch.isalnum())

-        self.family = family_class_dispatch(hostname, connector_name)()
+        self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
         self.site = pywikibot.Site(fam=self.family, code=language_code)
         self.categories = [
             pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
@@ -210,7 +206,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
 if __name__ == "__main__":
     HOSTNAME = "fallout.fandom.com"
     test_connector = MediaWikiConnector(
-        connector_name="Fallout",
         hostname=HOSTNAME,
         categories=["Fallout:_New_Vegas_factions"],
         pages=["Fallout: New Vegas"],
@@ -114,7 +114,9 @@ class DocumentBase(BaseModel):
     title: str | None = None
     from_ingestion_api: bool = False

-    def get_title_for_document_index(self) -> str | None:
+    def get_title_for_document_index(
+        self,
+    ) -> str | None:
         # If title is explicitly empty, return a None here for embedding purposes
         if self.title == "":
             return None
@@ -15,6 +15,7 @@ from playwright.sync_api import BrowserContext
 from playwright.sync_api import Playwright
 from playwright.sync_api import sync_playwright
 from requests_oauthlib import OAuth2Session  # type:ignore
+from urllib3.exceptions import MaxRetryError

 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID
@@ -83,6 +84,13 @@ def check_internet_connection(url: str) -> None:
     try:
         response = requests.get(url, timeout=3)
         response.raise_for_status()
+    except requests.exceptions.SSLError as e:
+        cause = (
+            e.args[0].reason
+            if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
+            else e.args
+        )
+        raise Exception(f"SSL error {str(cause)}")
     except (requests.RequestException, ValueError):
         raise Exception(f"Unable to reach {url} - check your internet connection")

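The new except branch above digs the underlying reason out of a requests SSLError, whose first argument is often a urllib3 MaxRetryError wrapping the actual certificate failure. The snippet below is a hypothetical reproduction of that unwrapping using hand-built exception objects rather than a real failing request.

```python
# Hypothetical reproduction of the SSL-error unwrapping above.
import requests
from urllib3.exceptions import MaxRetryError
from urllib3.exceptions import SSLError as Urllib3SSLError

inner = MaxRetryError(
    None, "https://self-signed.example", reason=Urllib3SSLError("bad cert")
)
e = requests.exceptions.SSLError(inner)

cause = (
    e.args[0].reason
    if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
    else e.args
)
print(f"SSL error {cause}")  # SSL error bad cert
```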
@@ -15,7 +15,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
         categories: list[str],
         pages: list[str],
         recurse_depth: int,
-        connector_name: str,
         language_code: str = "en",
         batch_size: int = INDEX_BATCH_SIZE,
     ) -> None:
@@ -24,7 +23,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
             categories=categories,
             pages=pages,
             recurse_depth=recurse_depth,
-            connector_name=connector_name,
             language_code=language_code,
             batch_size=batch_size,
         )
@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
 from danswer.one_shot_answer.models import ThreadMessage
 from danswer.search.retrieval.search_runner import download_nltk_data
-from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
@@ -11,6 +11,7 @@ from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import InputType
 from danswer.db.models import Connector
+from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import IndexAttempt
 from danswer.server.documents.models import ConnectorBase
 from danswer.server.documents.models import ObjectCreationIdResponse
@@ -85,6 +86,7 @@ def create_connector(
         input_type=connector_data.input_type,
         connector_specific_config=connector_data.connector_specific_config,
         refresh_freq=connector_data.refresh_freq,
+        indexing_start=connector_data.indexing_start,
         prune_freq=connector_data.prune_freq
         if connector_data.prune_freq is not None
         else DEFAULT_PRUNING_FREQ,
@@ -191,7 +193,8 @@ def fetch_latest_index_attempt_by_connector(
     for connector in connectors:
         latest_index_attempt = (
             db_session.query(IndexAttempt)
-            .filter(IndexAttempt.connector_id == connector.id)
+            .join(ConnectorCredentialPair)
+            .filter(ConnectorCredentialPair.connector_id == connector.id)
             .order_by(IndexAttempt.time_updated.desc())
             .first()
         )
@@ -207,13 +210,11 @@ def fetch_latest_index_attempts_by_status(
 ) -> list[IndexAttempt]:
     subquery = (
         db_session.query(
-            IndexAttempt.connector_id,
-            IndexAttempt.credential_id,
+            IndexAttempt.connector_credential_pair_id,
             IndexAttempt.status,
             func.max(IndexAttempt.time_updated).label("time_updated"),
         )
-        .group_by(IndexAttempt.connector_id)
-        .group_by(IndexAttempt.credential_id)
+        .group_by(IndexAttempt.connector_credential_pair_id)
         .group_by(IndexAttempt.status)
         .subquery()
     )
@@ -223,12 +224,13 @@ def fetch_latest_index_attempts_by_status(
     query = db_session.query(IndexAttempt).join(
         alias,
         and_(
-            IndexAttempt.connector_id == alias.connector_id,
-            IndexAttempt.credential_id == alias.credential_id,
+            IndexAttempt.connector_credential_pair_id
+            == alias.connector_credential_pair_id,
             IndexAttempt.status == alias.status,
             IndexAttempt.time_updated == alias.time_updated,
         ),
     )

     return cast(list[IndexAttempt], query.all())
@@ -6,6 +6,7 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
@@ -42,6 +43,17 @@ def get_connector_credential_pair(
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def get_connector_credential_source_from_id(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
) -> DocumentSource | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
result = db_session.execute(stmt)
|
||||
cc_pair = result.scalar_one_or_none()
|
||||
return cc_pair.connector.source if cc_pair else None
|
||||
|
||||
|
||||
def get_connector_credential_pair_from_id(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
@@ -75,17 +87,23 @@ def get_last_successful_attempt_time(
|
||||
# For Secondary Index we don't keep track of the latest success, so we have to calculate it live
|
||||
attempt = (
|
||||
db_session.query(IndexAttempt)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.filter(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.credential_id == credential_id,
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model.id,
|
||||
IndexAttempt.status == IndexingStatus.SUCCESS,
|
||||
)
|
||||
.order_by(IndexAttempt.time_started.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not attempt or not attempt.time_started:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
if connector and connector.indexing_start:
|
||||
return connector.indexing_start.timestamp()
|
||||
return 0.0
|
||||
|
||||
return attempt.time_started.timestamp()
|
||||
@@ -241,6 +259,12 @@ def remove_credential_from_connector(
|
||||
)
|
||||
|
||||
|
||||
def fetch_connector_credential_pairs(
|
||||
db_session: Session,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
return db_session.query(ConnectorCredentialPair).all()
|
||||
|
||||
|
||||
def resync_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
@@ -253,10 +277,14 @@ def resync_cc_pair(
|
||||
) -> IndexAttempt | None:
|
||||
query = (
|
||||
db_session.query(IndexAttempt)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
|
||||
.filter(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.credential_id == credential_id,
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
EmbeddingModel.status == IndexModelStatus.PRESENT,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,10 +2,13 @@ from typing import Any
|
||||
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import and_
|
||||
from sqlalchemy.sql.expression import or_
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
@@ -14,8 +17,10 @@ from danswer.connectors.google_drive.constants import (
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import CredentialDataUpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -74,6 +79,69 @@ def fetch_credential_by_id(
|
||||
return credential
|
||||
|
||||
|
||||
def fetch_credentials_by_source(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
document_source: DocumentSource | None = None,
|
||||
) -> list[Credential]:
|
||||
base_query = select(Credential).where(Credential.source == document_source)
|
||||
base_query = _attach_user_filters(base_query, user)
|
||||
credentials = db_session.execute(base_query).scalars().all()
|
||||
return list(credentials)
|
||||
|
||||
|
||||
def swap_credentials_connector(
|
||||
new_credential_id: int, connector_id: int, user: User | None, db_session: Session
|
||||
) -> ConnectorCredentialPair:
|
||||
# Check if the user has permission to use the new credential
|
||||
new_credential = fetch_credential_by_id(new_credential_id, user, db_session)
|
||||
if not new_credential:
|
||||
raise ValueError(
|
||||
f"No Credential found with id {new_credential_id} or user doesn't have permission to use it"
|
||||
)
|
||||
|
||||
# Existing pair
|
||||
existing_pair = db_session.execute(
|
||||
select(ConnectorCredentialPair).where(
|
||||
ConnectorCredentialPair.connector_id == connector_id
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not existing_pair:
|
||||
raise ValueError(
|
||||
f"No ConnectorCredentialPair found for connector_id {connector_id}"
|
||||
)
|
||||
|
||||
# Check if the new credential is compatible with the connector
|
||||
if new_credential.source != existing_pair.connector.source:
|
||||
raise ValueError(
|
||||
f"New credential source {new_credential.source} does not match connector source {existing_pair.connector.source}"
|
||||
)
|
||||
|
||||
db_session.execute(
|
||||
update(DocumentByConnectorCredentialPair)
|
||||
.where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== existing_pair.credential_id,
|
||||
)
|
||||
)
|
||||
.values(credential_id=new_credential_id)
|
||||
)
|
||||
|
||||
# Update the existing pair with the new credential
|
||||
existing_pair.credential_id = new_credential_id
|
||||
existing_pair.credential = new_credential
|
||||
|
||||
# Commit the changes
|
||||
db_session.commit()
|
||||
|
||||
# Refresh the object to ensure all relationships are up-to-date
|
||||
db_session.refresh(existing_pair)
|
||||
return existing_pair
|
||||
|
||||
|
||||
def create_credential(
|
||||
credential_data: CredentialBase,
|
||||
user: User | None,
|
||||
@@ -83,6 +151,8 @@ def create_credential(
|
||||
credential_json=credential_data.credential_json,
|
||||
user_id=user.id if user else None,
|
||||
admin_public=credential_data.admin_public,
|
||||
source=credential_data.source,
|
||||
name=credential_data.name,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
@@ -90,6 +160,30 @@ def create_credential(
|
||||
return credential
|
||||
|
||||
|
||||
def alter_credential(
|
||||
credential_id: int,
|
||||
credential_data: CredentialDataUpdateRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> Credential | None:
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.name = credential_data.name
|
||||
# Update only the keys present in credential_data.credential_json
|
||||
for key, value in credential_data.credential_json.items():
|
||||
credential.credential_json[key] = value
|
||||
|
||||
credential.user_id = user.id if user is not None else None
|
||||
|
||||
db_session.commit()
|
||||
return credential
|
||||
|
||||
|
||||
def update_credential(
|
||||
credential_id: int,
|
||||
credential_data: CredentialBase,
|
||||
@@ -136,6 +230,7 @@ def delete_credential(
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
@@ -149,11 +244,38 @@ def delete_credential(
|
||||
.all()
|
||||
)
|
||||
|
||||
if associated_connectors:
|
||||
raise ValueError(
|
||||
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
|
||||
"Please delete all associated connectors first."
|
||||
)
|
||||
associated_doc_cc_pairs = (
|
||||
db_session.query(DocumentByConnectorCredentialPair)
|
||||
.filter(DocumentByConnectorCredentialPair.credential_id == credential_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if associated_connectors or associated_doc_cc_pairs:
|
||||
if force:
|
||||
logger.warning(
|
||||
f"Force deleting credential {credential_id} and its associated records"
|
||||
)
|
||||
|
||||
# Delete DocumentByConnectorCredentialPair records first
|
||||
for doc_cc_pair in associated_doc_cc_pairs:
|
||||
db_session.delete(doc_cc_pair)
|
||||
|
||||
# Then delete ConnectorCredentialPair records
|
||||
for connector in associated_connectors:
|
||||
db_session.delete(connector)
|
||||
|
||||
# Commit these deletions before deleting the credential
|
||||
db_session.flush()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot delete credential as it is still associated with "
|
||||
f"{len(associated_connectors)} connector(s) and {len(associated_doc_cc_pairs)} document(s). "
|
||||
)
|
||||
|
||||
if force:
|
||||
logger.info(f"Force deleting credential {credential_id}")
|
||||
else:
|
||||
logger.info(f"Deleting credential {credential_id}")
|
||||
|
||||
db_session.delete(credential)
|
||||
db_session.commit()
|
||||
|
||||
@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.search.search_nlp_models import clean_model_name
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.server.manage.embedding.models import (
|
||||
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.server.documents.models import ConnectorCredentialPair
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
@@ -23,6 +24,22 @@ from danswer.utils.telemetry import RecordType
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_last_attempt_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
embedding_model_id: int,
|
||||
db_session: Session,
|
||||
) -> IndexAttempt | None:
|
||||
return (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model_id,
|
||||
)
|
||||
.order_by(IndexAttempt.time_updated.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def get_index_attempt(
|
||||
db_session: Session, index_attempt_id: int
|
||||
) -> IndexAttempt | None:
|
||||
@@ -31,15 +48,13 @@ def get_index_attempt(
|
||||
|
||||
|
||||
def create_index_attempt(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
connector_credential_pair_id: int,
|
||||
embedding_model_id: int,
|
||||
db_session: Session,
|
||||
from_beginning: bool = False,
|
||||
) -> int:
|
||||
new_attempt = IndexAttempt(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
embedding_model_id=embedding_model_id,
|
||||
from_beginning=from_beginning,
|
||||
status=IndexingStatus.NOT_STARTED,
|
||||
@@ -56,7 +71,9 @@ def get_inprogress_index_attempts(
|
||||
) -> list[IndexAttempt]:
|
||||
stmt = select(IndexAttempt)
|
||||
if connector_id is not None:
|
||||
stmt = stmt.where(IndexAttempt.connector_id == connector_id)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair.has(connector_id=connector_id)
|
||||
)
|
||||
stmt = stmt.where(IndexAttempt.status == IndexingStatus.IN_PROGRESS)
|
||||
|
||||
incomplete_attempts = db_session.scalars(stmt)
|
||||
@@ -65,21 +82,31 @@ def get_inprogress_index_attempts(
|
||||
|
||||
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
|
||||
"""This eagerly loads the connector and credential so that the db_session can be expired
|
||||
before running long-living indexing jobs, which causes increasing memory usage"""
|
||||
before running long-living indexing jobs, which causes increasing memory usage.
|
||||
|
||||
Results are ordered by time_created (oldest to newest)."""
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
|
||||
stmt = stmt.order_by(IndexAttempt.time_created)
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector), joinedload(IndexAttempt.credential)
|
||||
joinedload(IndexAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
joinedload(IndexAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.credential
|
||||
),
|
||||
)
|
||||
new_attempts = db_session.scalars(stmt)
|
||||
return list(new_attempts.all())
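As an illustrative aside (not part of the diff): the joinedload options above eagerly load the connector-credential pair together with its connector and credential, so that later relationship access does not trigger lazy-load queries. A minimal sketch of the pattern, assuming the danswer models and a SQLAlchemy db_session are already in scope:

from sqlalchemy import select
from sqlalchemy.orm import joinedload

# Sketch only: IndexAttempt, ConnectorCredentialPair, and db_session are assumed
# to be imported/bound as in the surrounding module.
stmt = select(IndexAttempt).options(
    joinedload(IndexAttempt.connector_credential_pair).joinedload(
        ConnectorCredentialPair.connector
    ),
    joinedload(IndexAttempt.connector_credential_pair).joinedload(
        ConnectorCredentialPair.credential
    ),
)
for attempt in db_session.scalars(stmt):
    # Both relationships are already populated, so no extra queries are issued here
    _ = attempt.connector_credential_pair.connector
    _ = attempt.connector_credential_pair.credential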
|
||||
|
||||
|
||||
def mark_attempt_in_progress__no_commit(
|
||||
def mark_attempt_in_progress(
|
||||
index_attempt: IndexAttempt,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
index_attempt.status = IndexingStatus.IN_PROGRESS
|
||||
index_attempt.time_started = index_attempt.time_started or func.now() # type: ignore
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_attempt_succeeded(
|
||||
@@ -103,7 +130,7 @@ def mark_attempt_failed(
|
||||
db_session.add(index_attempt)
|
||||
db_session.commit()
|
||||
|
||||
source = index_attempt.connector.source
|
||||
source = index_attempt.connector_credential_pair.connector.source
|
||||
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
|
||||
|
||||
|
||||
@@ -128,11 +155,16 @@ def get_last_attempt(
|
||||
embedding_model_id: int | None,
|
||||
db_session: Session,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt).where(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.credential_id == credential_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model_id,
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(ConnectorCredentialPair)
|
||||
.where(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Note, the below is using time_created instead of time_updated
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
|
||||
@@ -145,8 +177,7 @@ def get_latest_index_attempts(
|
||||
db_session: Session,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
ids_stmt = select(
|
||||
IndexAttempt.connector_id,
|
||||
IndexAttempt.credential_id,
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.time_created).label("max_time_created"),
|
||||
).join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
|
||||
|
||||
@@ -158,43 +189,95 @@ def get_latest_index_attempts(
|
||||
where_stmts: list[ColumnElement] = []
|
||||
for connector_credential_pair_identifier in connector_credential_pair_identifiers:
|
||||
where_stmts.append(
|
||||
and_(
|
||||
IndexAttempt.connector_id
|
||||
== connector_credential_pair_identifier.connector_id,
|
||||
IndexAttempt.credential_id
|
||||
== connector_credential_pair_identifier.credential_id,
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== (
|
||||
select(ConnectorCredentialPair.id)
|
||||
.where(
|
||||
ConnectorCredentialPair.connector_id
|
||||
== connector_credential_pair_identifier.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
== connector_credential_pair_identifier.credential_id,
|
||||
)
|
||||
.scalar_subquery()
|
||||
)
|
||||
)
|
||||
if where_stmts:
|
||||
ids_stmt = ids_stmt.where(or_(*where_stmts))
|
||||
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_id, IndexAttempt.credential_id)
|
||||
ids_subqery = ids_stmt.subquery()
|
||||
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
ids_subquery = ids_stmt.subquery()
|
||||
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(
|
||||
ids_subqery,
|
||||
and_(
|
||||
ids_subqery.c.connector_id == IndexAttempt.connector_id,
|
||||
ids_subqery.c.credential_id == IndexAttempt.credential_id,
|
||||
),
|
||||
ids_subquery,
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== ids_subquery.c.connector_credential_pair_id,
|
||||
)
|
||||
.where(IndexAttempt.time_created == ids_subqery.c.max_time_created)
|
||||
.where(IndexAttempt.time_created == ids_subquery.c.max_time_created)
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def get_index_attempts_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
only_current: bool = True,
|
||||
disinclude_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
)
|
||||
if disinclude_finished:
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
)
|
||||
)
|
||||
if only_current:
|
||||
stmt = stmt.join(EmbeddingModel).where(
|
||||
EmbeddingModel.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
|
||||
stmt = stmt.order_by(IndexAttempt.time_created.desc())
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def get_latest_finished_index_attempt_for_cc_pair(
|
||||
connector_credential_pair_id: int,
|
||||
db_session: Session,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
.order_by(desc(IndexAttempt.time_created))
|
||||
.limit(1)
|
||||
)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_index_attempts_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_identifier: ConnectorCredentialPairIdentifier,
|
||||
only_current: bool = True,
|
||||
disinclude_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
stmt = select(IndexAttempt).where(
|
||||
and_(
|
||||
IndexAttempt.connector_id == cc_pair_identifier.connector_id,
|
||||
IndexAttempt.credential_id == cc_pair_identifier.credential_id,
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(ConnectorCredentialPair)
|
||||
.where(
|
||||
and_(
|
||||
ConnectorCredentialPair.connector_id == cc_pair_identifier.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
== cc_pair_identifier.credential_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
if disinclude_finished:
|
||||
@@ -218,9 +301,11 @@ def delete_index_attempts(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
stmt = delete(IndexAttempt).where(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.credential_id == credential_id,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
@@ -254,9 +339,11 @@ def cancel_indexing_attempts_for_connector(
|
||||
db_session: Session,
|
||||
include_secondary_index: bool = False,
|
||||
) -> None:
|
||||
stmt = delete(IndexAttempt).where(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.status == IndexingStatus.NOT_STARTED,
|
||||
stmt = (
|
||||
delete(IndexAttempt)
|
||||
.where(IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id)
|
||||
.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
|
||||
)
|
||||
|
||||
if not include_secondary_index:
|
||||
@@ -296,7 +383,7 @@ def count_unique_cc_pairs_with_successful_index_attempts(
|
||||
Then do distinct by connector_id and credential_id which is equivalent to the cc-pair. Finally,
|
||||
do a count to get the total number of unique cc-pairs with successful attempts"""
|
||||
unique_pairs_count = (
|
||||
db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id)
|
||||
db_session.query(IndexAttempt.connector_credential_pair)
|
||||
.filter(
|
||||
IndexAttempt.embedding_model_id == embedding_model_id,
|
||||
IndexAttempt.status == IndexingStatus.SUCCESS,
|
||||
|
||||
@@ -1,15 +1,41 @@
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
|
||||
|
||||
def update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id: int,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Delete existing relationships
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
# Add new relationships from given group_ids
|
||||
if group_ids:
|
||||
new_relationships = [
|
||||
LLMProvider__UserGroup(
|
||||
llm_provider_id=llm_provider_id,
|
||||
user_group_id=group_id,
|
||||
)
|
||||
for group_id in group_ids
|
||||
]
|
||||
db_session.add_all(new_relationships)
|
||||
|
||||
|
||||
def upsert_cloud_embedding_provider(
|
||||
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
||||
) -> CloudEmbeddingProvider:
|
||||
@@ -36,36 +62,36 @@ def upsert_llm_provider(
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
if existing_llm_provider:
|
||||
existing_llm_provider.provider = llm_provider.provider
|
||||
existing_llm_provider.api_key = llm_provider.api_key
|
||||
existing_llm_provider.api_base = llm_provider.api_base
|
||||
existing_llm_provider.api_version = llm_provider.api_version
|
||||
existing_llm_provider.custom_config = llm_provider.custom_config
|
||||
existing_llm_provider.default_model_name = llm_provider.default_model_name
|
||||
existing_llm_provider.fast_default_model_name = (
|
||||
llm_provider.fast_default_model_name
|
||||
)
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
db_session.commit()
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
# if it does not exist, create a new entry
|
||||
llm_provider_model = LLMProviderModel(
|
||||
name=llm_provider.name,
|
||||
provider=llm_provider.provider,
|
||||
api_key=llm_provider.api_key,
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
default_model_name=llm_provider.default_model_name,
|
||||
fast_default_model_name=llm_provider.fast_default_model_name,
|
||||
model_names=llm_provider.model_names,
|
||||
is_default_provider=None,
|
||||
|
||||
if not existing_llm_provider:
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
existing_llm_provider.provider = llm_provider.provider
|
||||
existing_llm_provider.api_key = llm_provider.api_key
|
||||
existing_llm_provider.api_base = llm_provider.api_base
|
||||
existing_llm_provider.api_version = llm_provider.api_version
|
||||
existing_llm_provider.custom_config = llm_provider.custom_config
|
||||
existing_llm_provider.default_model_name = llm_provider.default_model_name
|
||||
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
existing_llm_provider.is_public = llm_provider.is_public
|
||||
existing_llm_provider.display_model_names = llm_provider.display_model_names
|
||||
|
||||
if not existing_llm_provider.id:
|
||||
# If it's not already in the db, we need to generate an ID by flushing
|
||||
db_session.flush()
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
group_ids=llm_provider.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.add(llm_provider_model)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return FullLLMProvider.from_model(llm_provider_model)
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
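As a hedged aside (not part of the diff): the rewrite above collapses the create and update branches into one upsert shape: fetch by name, create if missing, assign fields unconditionally, flush to obtain an id, then sync related rows before committing. A generic sketch of that pattern with hypothetical names (Provider is not a danswer model):

from sqlalchemy import select


def upsert_provider(db_session, name: str, api_key: str):
    # Hypothetical Provider model with name/api_key columns and an autoincrement id
    provider = db_session.scalar(select(Provider).where(Provider.name == name))
    if not provider:
        provider = Provider(name=name)
        db_session.add(provider)

    provider.api_key = api_key  # the same assignments cover both create and update

    if not provider.id:
        db_session.flush()  # assign a primary key before touching relationship tables

    db_session.commit()
    return provider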
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
@@ -74,8 +100,29 @@ def fetch_existing_embedding_providers(
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
) -> list[LLMProviderModel]:
|
||||
if not user:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
stmt = select(LLMProviderModel).distinct()
|
||||
user_groups_subquery = (
|
||||
select(User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id == user.id)
|
||||
.subquery()
|
||||
)
|
||||
access_conditions = or_(
|
||||
LLMProviderModel.is_public,
|
||||
LLMProviderModel.id.in_( # User is part of a group that has access
|
||||
select(LLMProvider__UserGroup.llm_provider_id).where(
|
||||
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
|
||||
)
|
||||
),
|
||||
)
|
||||
stmt = stmt.where(access_conditions)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
@@ -119,6 +166,13 @@ def remove_embedding_provider(
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
# Remove LLMProvider's dependent relationships
|
||||
db_session.execute(
|
||||
delete(LLMProvider__UserGroup).where(
|
||||
LLMProvider__UserGroup.llm_provider_id == provider_id
|
||||
)
|
||||
)
|
||||
# Remove LLMProvider
|
||||
db_session.execute(
|
||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from uuid import UUID
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
|
||||
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Enum
|
||||
@@ -120,6 +121,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.ARRAY(Integer), nullable=True
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user", lazy="joined"
|
||||
@@ -337,6 +342,9 @@ class ConnectorCredentialPair(Base):
|
||||
back_populates="connector_credential_pairs",
|
||||
overlaps="document_set",
|
||||
)
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="connector_credential_pair"
|
||||
)
|
||||
|
||||
|
||||
class Document(Base):
|
||||
@@ -416,6 +424,9 @@ class Connector(Base):
|
||||
connector_specific_config: Mapped[dict[str, Any]] = mapped_column(
|
||||
postgresql.JSONB()
|
||||
)
|
||||
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -434,14 +445,17 @@ class Connector(Base):
|
||||
documents_by_connector: Mapped[
|
||||
list["DocumentByConnectorCredentialPair"]
|
||||
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="connector"
|
||||
)
|
||||
|
||||
|
||||
class Credential(Base):
|
||||
__tablename__ = "credential"
|
||||
|
||||
name: Mapped[str] = mapped_column(String, nullable=True)
|
||||
|
||||
source: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
@@ -462,9 +476,7 @@ class Credential(Base):
|
||||
documents_by_credential: Mapped[
|
||||
list["DocumentByConnectorCredentialPair"]
|
||||
] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="credential"
|
||||
)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="credentials")
|
||||
|
||||
|
||||
@@ -534,13 +546,10 @@ class IndexAttempt(Base):
|
||||
__tablename__ = "index_attempt"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
connector_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector.id"),
|
||||
nullable=True,
|
||||
)
|
||||
credential_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("credential.id"),
|
||||
nullable=True,
|
||||
|
||||
connector_credential_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Some index attempts that run from beginning will still have this as False
|
||||
@@ -578,12 +587,10 @@ class IndexAttempt(Base):
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
connector: Mapped[Connector] = relationship(
|
||||
"Connector", back_populates="index_attempts"
|
||||
)
|
||||
credential: Mapped[Credential] = relationship(
|
||||
"Credential", back_populates="index_attempts"
|
||||
connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship(
|
||||
"ConnectorCredentialPair", back_populates="index_attempts"
|
||||
)
|
||||
|
||||
embedding_model: Mapped[EmbeddingModel] = relationship(
|
||||
"EmbeddingModel", back_populates="index_attempts"
|
||||
)
|
||||
@@ -591,8 +598,7 @@ class IndexAttempt(Base):
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"connector_id",
|
||||
"credential_id",
|
||||
"connector_credential_pair_id",
|
||||
"time_created",
|
||||
),
|
||||
)
|
||||
@@ -600,7 +606,6 @@ class IndexAttempt(Base):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<IndexAttempt(id={self.id!r}, "
|
||||
f"connector_id={self.connector_id!r}, "
|
||||
f"status={self.status!r}, "
|
||||
f"error_msg={self.error_msg!r})>"
|
||||
f"time_created={self.time_created!r}, "
|
||||
@@ -821,6 +826,8 @@ class ChatMessage(Base):
|
||||
secondary="chat_message__search_doc",
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
@@ -923,6 +930,11 @@ class LLMProvider(Base):
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# Models to actually display to users
# If nulled out, we assume in the application logic we should present all
|
||||
display_model_names: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# The LLMs that are available for this provider. Only required if not a default provider.
|
||||
# If a default provider, then the LLM options are pulled from the `options.py` file.
|
||||
# If needed, can be pulled out as a separate table in the future.
|
||||
@@ -932,6 +944,13 @@ class LLMProvider(Base):
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
"UserGroup",
|
||||
secondary="llm_provider__user_group",
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
@@ -1109,7 +1128,10 @@ class Persona(Base):
|
||||
# where lower value IDs (e.g. created earlier) are displayed first
|
||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
uploaded_image_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
icon_color: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
icon_shape: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# These are only defaults, users can select from all if desired
|
||||
prompts: Mapped[list[Prompt]] = relationship(
|
||||
@@ -1137,6 +1159,7 @@ class Persona(Base):
|
||||
viewonly=True,
|
||||
)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
"UserGroup",
|
||||
secondary="persona__user_group",
|
||||
@@ -1360,6 +1383,17 @@ class Persona__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class LLMProvider__UserGroup(Base):
|
||||
__tablename__ = "llm_provider__user_group"
|
||||
|
||||
llm_provider_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("llm_provider.id"), primary_key=True
|
||||
)
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class DocumentSet__UserGroup(Base):
|
||||
__tablename__ = "document_set__user_group"
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import not_
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
@@ -24,6 +25,7 @@ from danswer.db.models import StarterMessage
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.features.persona.models import CreatePersonaRequest
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
@@ -80,6 +82,9 @@ def create_update_persona(
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
is_public=create_persona_request.is_public,
|
||||
db_session=db_session,
|
||||
icon_color=create_persona_request.icon_color,
|
||||
icon_shape=create_persona_request.icon_shape,
|
||||
uploaded_image_id=create_persona_request.uploaded_image_id,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
@@ -328,6 +333,9 @@ def upsert_persona(
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
commit: bool = True,
|
||||
icon_color: str | None = None,
|
||||
icon_shape: int | None = None,
|
||||
uploaded_image_id: str | None = None,
|
||||
) -> Persona:
|
||||
if persona_id is not None:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
@@ -383,6 +391,9 @@ def upsert_persona(
|
||||
persona.starter_messages = starter_messages
|
||||
persona.deleted = False # Un-delete if previously deleted
|
||||
persona.is_public = is_public
|
||||
persona.icon_color = icon_color
|
||||
persona.icon_shape = icon_shape
|
||||
persona.uploaded_image_id = uploaded_image_id
|
||||
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
@@ -415,6 +426,9 @@ def upsert_persona(
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
starter_messages=starter_messages,
|
||||
tools=tools or [],
|
||||
icon_shape=icon_shape,
|
||||
icon_color=icon_color,
|
||||
uploaded_image_id=uploaded_image_id,
|
||||
)
|
||||
db_session.add(persona)
|
||||
|
||||
@@ -548,6 +562,8 @@ def get_default_prompt__read_only() -> Prompt:
|
||||
return _get_default_prompt(db_session)
|
||||
|
||||
|
||||
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
|
||||
# a direct mapping indicating whether a user has access to a specific persona?
|
||||
def get_persona_by_id(
|
||||
persona_id: int,
|
||||
# if user is `None` assume the user is an admin or auth is disabled
|
||||
@@ -556,16 +572,38 @@ def get_persona_by_id(
|
||||
include_deleted: bool = False,
|
||||
is_for_edit: bool = True, # NOTE: assume true for safety
|
||||
) -> Persona:
|
||||
stmt = select(Persona).where(Persona.id == persona_id)
|
||||
stmt = (
|
||||
select(Persona)
|
||||
.options(selectinload(Persona.users), selectinload(Persona.groups))
|
||||
.where(Persona.id == persona_id)
|
||||
)
|
||||
|
||||
or_conditions = []
|
||||
|
||||
# if user is an admin, they should have access to all Personas
|
||||
# and will skip the following clause
|
||||
if user is not None and user.role != UserRole.ADMIN:
|
||||
or_conditions.extend([Persona.user_id == user.id, Persona.user_id.is_(None)])
|
||||
# the user is not an admin
|
||||
isPersonaUnowned = Persona.user_id.is_(
|
||||
None
|
||||
) # allow access if persona user id is None
|
||||
isUserCreator = (
|
||||
Persona.user_id == user.id
|
||||
) # allow access if user created the persona
|
||||
or_conditions.extend([isPersonaUnowned, isUserCreator])
|
||||
|
||||
# if we aren't editing, also give access to all public personas
|
||||
# if we aren't editing, also give access if:
|
||||
# 1. the user is authorized for this persona
|
||||
# 2. the user is in an authorized group for this persona
|
||||
# 3. if the persona is public
|
||||
if not is_for_edit:
|
||||
isSharedWithUser = Persona.users.any(
|
||||
id=user.id
|
||||
) # allow access if user is in allowed users
|
||||
isSharedWithGroup = Persona.groups.any(
|
||||
UserGroup.users.any(id=user.id)
|
||||
) # allow access if user is in any allowed group
|
||||
or_conditions.extend([isSharedWithUser, isSharedWithGroup])
|
||||
or_conditions.append(Persona.is_public.is_(True))
|
||||
|
||||
if or_conditions:
|
||||
|
||||
@@ -331,12 +331,18 @@ def _index_vespa_chunk(
|
||||
document = chunk.source_document
|
||||
# No mini-chunk documents in Vespa; mini-chunk vectors are stored in the chunk itself
|
||||
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
|
||||
|
||||
embeddings = chunk.embeddings
|
||||
|
||||
if chunk.embeddings.full_embedding is None:
|
||||
embeddings.full_embedding = chunk.title_embedding
|
||||
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
|
||||
|
||||
if embeddings.mini_chunk_embeddings:
|
||||
for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
|
||||
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
|
||||
if m_c_embed is None:
|
||||
embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
|
||||
else:
|
||||
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
|
||||
|
||||
title = document.get_title_for_document_index()
|
||||
|
||||
@@ -346,11 +352,15 @@ def _index_vespa_chunk(
|
||||
BLURB: remove_invalid_unicode_chars(chunk.blurb),
|
||||
TITLE: remove_invalid_unicode_chars(title) if title else None,
|
||||
SKIP_TITLE_EMBEDDING: not title,
|
||||
CONTENT: remove_invalid_unicode_chars(chunk.content),
|
||||
# For the BM25 index, the keyword suffix is used; the vector is already generated with the more
# natural-language representation of the metadata section
|
||||
CONTENT: remove_invalid_unicode_chars(
|
||||
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
|
||||
),
|
||||
# This duplication of `content` is needed for keyword highlighting
|
||||
# Note that it's not exactly the same as the actual content
|
||||
# which contains the title prefix and metadata suffix
|
||||
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
|
||||
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
|
||||
SOURCE_TYPE: str(document.source.value),
|
||||
SOURCE_LINKS: json.dumps(chunk.source_links),
|
||||
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
|
||||
@@ -358,7 +368,7 @@ def _index_vespa_chunk(
|
||||
METADATA: json.dumps(document.metadata),
|
||||
# Save as a list for efficient extraction as an Attribute
|
||||
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
|
||||
METADATA_SUFFIX: chunk.metadata_suffix,
|
||||
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
TITLE_EMBEDDING: chunk.title_embedding,
|
||||
BOOST: chunk.boost,
|
||||
|
||||
@@ -103,6 +103,8 @@ def port_api_key_to_postgres() -> None:
|
||||
default_model_name=default_model_name,
|
||||
fast_default_model_name=default_fast_model_name,
|
||||
model_names=None,
|
||||
display_model_names=[],
|
||||
is_public=True,
|
||||
)
|
||||
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
|
||||
update_default_provider(db_session, llm_provider.id)
|
||||
|
||||
@@ -6,7 +6,6 @@ from danswer.configs.app_configs import BLURB_SIZE
|
||||
from danswer.configs.app_configs import MINI_CHUNK_SIZE
|
||||
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
|
||||
from danswer.configs.constants import RETURN_SEPARATOR
|
||||
from danswer.configs.constants import SECTION_SEPARATOR
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
@@ -15,12 +14,12 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
)
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
from llama_index.text_splitter import SentenceSplitter # type:ignore
|
||||
|
||||
|
||||
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
|
||||
@@ -28,6 +27,8 @@ if TYPE_CHECKING:
|
||||
CHUNK_OVERLAP = 0
|
||||
# Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
|
||||
# overwhelm the actual contents of the chunk
|
||||
# For example in a rare case, this could be 128 tokens for the 512 chunk and title prefix
|
||||
# could be another 128 tokens leaving 256 for the actual contents
|
||||
MAX_METADATA_PERCENTAGE = 0.25
|
||||
CHUNK_MIN_CONTENT = 256
|
||||
|
||||
@@ -36,14 +37,7 @@ logger = setup_logger()
|
||||
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
|
||||
|
||||
|
||||
def extract_blurb(text: str, blurb_size: int) -> str:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
token_count_func = get_default_tokenizer().tokenize
|
||||
blurb_splitter = SentenceSplitter(
|
||||
tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
|
||||
)
|
||||
|
||||
def extract_blurb(text: str, blurb_splitter: "SentenceSplitter") -> str:
|
||||
return blurb_splitter.split_text(text)[0]
|
||||
|
||||
|
||||
@@ -52,33 +46,25 @@ def chunk_large_section(
|
||||
section_link_text: str,
|
||||
document: Document,
|
||||
start_chunk_id: int,
|
||||
tokenizer: "AutoTokenizer",
|
||||
chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
chunk_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
blurb: str,
|
||||
chunk_splitter: "SentenceSplitter",
|
||||
title_prefix: str = "",
|
||||
metadata_suffix: str = "",
|
||||
metadata_suffix_semantic: str = "",
|
||||
metadata_suffix_keyword: str = "",
|
||||
) -> list[DocAwareChunk]:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
blurb = extract_blurb(section_text, blurb_size)
|
||||
|
||||
sentence_aware_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize, chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
|
||||
split_texts = sentence_aware_splitter.split_text(section_text)
|
||||
split_texts = chunk_splitter.split_text(section_text)
|
||||
|
||||
chunks = [
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=start_chunk_id + chunk_ind,
|
||||
blurb=blurb,
|
||||
content=f"{title_prefix}{chunk_str}{metadata_suffix}",
|
||||
content_summary=chunk_str,
|
||||
content=chunk_str,
|
||||
source_links={0: section_link_text},
|
||||
section_continuation=(chunk_ind != 0),
|
||||
metadata_suffix=metadata_suffix,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
for chunk_ind, chunk_str in enumerate(split_texts)
|
||||
]
|
||||
@@ -86,42 +72,87 @@ def chunk_large_section(
|
||||
|
||||
|
||||
def _get_metadata_suffix_for_document_index(
|
||||
metadata: dict[str, str | list[str]]
|
||||
) -> str:
|
||||
metadata: dict[str, str | list[str]], include_separator: bool = False
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding
|
||||
and a string of all of the values for the keyword search
|
||||
|
||||
For example, if we have the following metadata:
|
||||
{
|
||||
"author": "John Doe",
|
||||
"space": "Engineering"
|
||||
}
|
||||
The vector embedding string should include the relation between the key and value, whereas for keyword search we only want John Doe
and Engineering. The keys are repeated and add much more noise.
|
||||
"""
|
||||
if not metadata:
|
||||
return ""
|
||||
return "", ""
|
||||
|
||||
metadata_str = "Metadata:\n"
|
||||
metadata_values = []
|
||||
for key, value in metadata.items():
|
||||
if key in get_metadata_keys_to_ignore():
|
||||
continue
|
||||
|
||||
value_str = ", ".join(value) if isinstance(value, list) else value
|
||||
|
||||
if isinstance(value, list):
|
||||
metadata_values.extend(value)
|
||||
else:
|
||||
metadata_values.append(value)
|
||||
|
||||
metadata_str += f"\t{key} - {value_str}\n"
|
||||
return metadata_str.strip()
|
||||
|
||||
metadata_semantic = metadata_str.strip()
|
||||
metadata_keyword = " ".join(metadata_values)
|
||||
|
||||
if include_separator:
|
||||
return RETURN_SEPARATOR + metadata_semantic, RETURN_SEPARATOR + metadata_keyword
|
||||
return metadata_semantic, metadata_keyword
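As an illustrative aside (not part of the diff), here is a minimal, self-contained sketch of the semantic/keyword split the docstring above describes, with the ignore-list and separator handling omitted:

def build_metadata_suffixes(metadata: dict[str, str | list[str]]) -> tuple[str, str]:
    # Mirrors the behavior described above: key/value pairs for the vector text,
    # values only for the keyword text.
    if not metadata:
        return "", ""
    semantic = "Metadata:\n"
    keyword_values: list[str] = []
    for key, value in metadata.items():
        value_str = ", ".join(value) if isinstance(value, list) else value
        keyword_values.extend(value if isinstance(value, list) else [value])
        semantic += f"\t{key} - {value_str}\n"
    return semantic.strip(), " ".join(keyword_values)


semantic, keyword = build_metadata_suffixes({"author": "John Doe", "space": "Engineering"})
# semantic -> "Metadata:\n\tauthor - John Doe\n\tspace - Engineering"
# keyword  -> "John Doe Engineering"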
|
||||
|
||||
|
||||
def chunk_document(
|
||||
document: Document,
|
||||
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
subsection_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
blurb_size: int = BLURB_SIZE, # Used for both title and content
|
||||
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
|
||||
) -> list[DocAwareChunk]:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
tokenizer = get_default_tokenizer()
|
||||
|
||||
title = document.get_title_for_document_index()
|
||||
title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
|
||||
blurb_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize, chunk_size=blurb_size, chunk_overlap=0
|
||||
)
|
||||
|
||||
chunk_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize,
|
||||
chunk_size=chunk_tok_size,
|
||||
chunk_overlap=subsection_overlap,
|
||||
)
|
||||
|
||||
title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(tokenizer.tokenize(title_prefix))
|
||||
|
||||
metadata_suffix = ""
|
||||
metadata_suffix_semantic = ""
|
||||
metadata_suffix_keyword = ""
|
||||
metadata_tokens = 0
|
||||
if include_metadata:
|
||||
metadata = _get_metadata_suffix_for_document_index(document.metadata)
|
||||
metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
|
||||
metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
|
||||
(
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
) = _get_metadata_suffix_for_document_index(
|
||||
document.metadata, include_separator=True
|
||||
)
|
||||
metadata_tokens = len(tokenizer.tokenize(metadata_suffix_semantic))
|
||||
|
||||
if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
|
||||
metadata_suffix = ""
|
||||
# Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model
# context; there is no length limit for the keyword component
|
||||
metadata_suffix_semantic = ""
|
||||
metadata_tokens = 0
|
||||
|
||||
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
|
||||
@@ -130,7 +161,7 @@ def chunk_document(
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
content_token_limit = chunk_tok_size
|
||||
title_prefix = ""
|
||||
metadata_suffix = ""
|
||||
metadata_suffix_semantic = ""
|
||||
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
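As a hedged worked example (not part of the diff), the token budget above behaves roughly as follows; the 512-token chunk size matches the comment near the top of the file, while the title and metadata token counts are made-up numbers for illustration:

chunk_tok_size = 512
CHUNK_MIN_CONTENT = 256
MAX_METADATA_PERCENTAGE = 0.25

title_tokens = 40
metadata_tokens = 100  # under 512 * 0.25 == 128, so the semantic suffix is kept

content_token_limit = chunk_tok_size - title_tokens - metadata_tokens  # 372
if content_token_limit <= CHUNK_MIN_CONTENT:
    # Too little room left for real content: drop the prefix/suffix entirely
    content_token_limit = chunk_tok_size
    title_prefix = ""
    metadata_suffix_semantic = ""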
|
||||
@@ -151,12 +182,13 @@ def chunk_document(
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
blurb=extract_blurb(chunk_text, blurb_splitter),
|
||||
content=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
)
|
||||
link_offsets = {}
|
||||
@@ -167,12 +199,11 @@ def chunk_document(
|
||||
section_link_text=section_link_text,
|
||||
document=document,
|
||||
start_chunk_id=len(chunks),
|
||||
tokenizer=tokenizer,
|
||||
chunk_size=content_token_limit,
|
||||
chunk_overlap=subsection_overlap,
|
||||
blurb_size=blurb_size,
|
||||
chunk_splitter=chunk_splitter,
|
||||
blurb=extract_blurb(section_text, blurb_splitter),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix=metadata_suffix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
chunks.extend(large_section_chunks)
|
||||
continue
|
||||
@@ -193,12 +224,13 @@ def chunk_document(
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
blurb=extract_blurb(chunk_text, blurb_splitter),
|
||||
content=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
)
|
||||
link_offsets = {0: section_link_text}
|
||||
@@ -211,12 +243,13 @@ def chunk_document(
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
blurb=extract_blurb(chunk_text, blurb_splitter),
|
||||
content=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
)
|
||||
return chunks
|
||||
|
||||
@@ -4,7 +4,6 @@ from abc import abstractmethod
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
|
||||
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
@@ -14,8 +13,7 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
|
||||
from danswer.indexing.models import ChunkEmbedding
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.indexing.models import IndexChunk
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.utils.batching import batch_list
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
@@ -66,48 +64,38 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
# The below are globally set; this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
retrim_content=True,
|
||||
)
|
||||
|
||||
def embed_chunks(
|
||||
self,
|
||||
chunks: list[DocAwareChunk],
|
||||
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
|
||||
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
|
||||
) -> list[IndexChunk]:
|
||||
# Cache the Title embeddings to only have to do it once
|
||||
title_embed_dict: dict[str, list[float]] = {}
|
||||
title_embed_dict: dict[str, list[float] | None] = {}
|
||||
embedded_chunks: list[IndexChunk] = []
|
||||
|
||||
# Create Mini Chunks for more precise matching of details
|
||||
# Off by default with unedited settings
|
||||
chunk_texts = []
|
||||
chunk_texts: list[str] = []
|
||||
chunk_mini_chunks_count = {}
|
||||
for chunk_ind, chunk in enumerate(chunks):
|
||||
chunk_texts.append(chunk.content)
|
||||
# The whole chunk including the prefix/suffix is included in the overall vector representation
|
||||
chunk_texts.append(
|
||||
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
|
||||
)
|
||||
mini_chunk_texts = (
|
||||
split_chunk_text_into_mini_chunks(chunk.content_summary)
|
||||
split_chunk_text_into_mini_chunks(chunk.content)
|
||||
if enable_mini_chunk
|
||||
else []
|
||||
)
|
||||
chunk_texts.extend(mini_chunk_texts)
|
||||
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
|
||||
|
||||
# Batching for embedding
|
||||
text_batches = batch_list(chunk_texts, batch_size)
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
len_text_batches = len(text_batches)
|
||||
for idx, text_batch in enumerate(text_batches, start=1):
|
||||
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
|
||||
# Normalize embeddings is only configured via model_configs.py, be sure to use right
|
||||
# value for the set loss
|
||||
embeddings.extend(
|
||||
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
|
||||
)
|
||||
|
||||
# Replace line above with the line below for easy debugging of indexing flow
|
||||
# skipping the actual model
|
||||
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
|
||||
embeddings = self.embedding_model.encode(
|
||||
chunk_texts, text_type=EmbedTextType.PASSAGE
|
||||
)
|
||||
|
||||
chunk_titles = {
|
||||
chunk.source_document.get_title_for_document_index() for chunk in chunks
|
||||
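Note: the rewritten embed_chunks above builds one embeddable string per chunk (title prefix + content + semantic metadata suffix), optionally appends mini-chunks, and records how many vectors each chunk owns so the results can be mapped back. A minimal standalone sketch of that assembly step, with SimpleChunk and fake_split_into_mini_chunks as illustrative stand-ins for the real DocAwareChunk and splitter:

# Minimal sketch of the chunk-text assembly shown in the hunk above.
# `SimpleChunk` and `fake_split_into_mini_chunks` are illustrative stand-ins,
# not the real DocAwareChunk / split_chunk_text_into_mini_chunks.
from dataclasses import dataclass


@dataclass
class SimpleChunk:
    title_prefix: str
    content: str
    metadata_suffix_semantic: str


def fake_split_into_mini_chunks(text: str, size: int = 150) -> list[str]:
    # naive fixed-width splitter used only for illustration
    return [text[i : i + size] for i in range(0, len(text), size)]


def build_embeddable_texts(
    chunks: list[SimpleChunk], enable_mini_chunk: bool = False
) -> tuple[list[str], dict[int, int]]:
    texts: list[str] = []
    mini_chunk_counts: dict[int, int] = {}
    for ind, chunk in enumerate(chunks):
        # the full prefixed/suffixed chunk is what gets the "main" vector
        texts.append(
            f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
        )
        minis = fake_split_into_mini_chunks(chunk.content) if enable_mini_chunk else []
        texts.extend(minis)
        # remember how many vectors belong to each chunk (1 main + N minis)
        mini_chunk_counts[ind] = 1 + len(minis)
    return texts, mini_chunk_counts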
@@ -116,16 +104,15 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
# Drop any None or empty strings
chunk_titles_list = [title for title in chunk_titles if title]

# Embed Titles in batches
title_batches = batch_list(chunk_titles_list, batch_size)
len_title_batches = len(title_batches)
for ind_batch, title_batch in enumerate(title_batches, start=1):
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
if chunk_titles_list:
title_embeddings = self.embedding_model.encode(
title_batch, text_type=EmbedTextType.PASSAGE
chunk_titles_list, text_type=EmbedTextType.PASSAGE
)
title_embed_dict.update(
{title: vector for title, vector in zip(title_batch, title_embeddings)}
{
title: vector
for title, vector in zip(chunk_titles_list, title_embeddings)
}
)

# Mapping embeddings to chunks
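Note: titles are embedded once per unique, non-empty title and the vectors are reused for every chunk of the same document. A small sketch of that caching pattern, with `embed` standing in for the real model call:

# Sketch of the title-embedding cache: embed unique, non-empty titles once,
# then look them up per chunk. `embed` is a stand-in for the model call.
from collections.abc import Callable

Vector = list[float]


def embed_titles_once(
    titles: list[str | None], embed: Callable[[list[str]], list[Vector]]
) -> dict[str, Vector]:
    unique_titles = [t for t in set(titles) if t]  # drop None / empty strings
    if not unique_titles:
        return {}
    vectors = embed(unique_titles)  # one model call for all distinct titles
    return dict(zip(unique_titles, vectors))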
@@ -184,4 +171,6 @@ def get_embedding_model_from_db_embedding_model(
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
api_key=db_embedding_model.api_key,
|
||||
)
|
||||
|
||||
@@ -124,6 +124,19 @@ def index_doc_batch(
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
# Skip documents that have neither title nor content
documents_to_process = []
for document in documents:
if not document.title and not any(
section.text.strip() for section in document.sections
):
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content"
)
else:
documents_to_process.append(document)
documents = documents_to_process

document_ids = [document.id for document in documents]
db_docs = get_documents_by_ids(
document_ids=document_ids,
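Note: the new guard above drops documents that have neither a title nor any non-blank section text before they reach the rest of the pipeline. A compact sketch of the same filter, using stand-in Doc/Section types rather than the real connector models:

# Sketch of the "skip documents with neither title nor content" filter.
# `Doc` / `Section` are illustrative stand-ins for the real connector models.
import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class Section:
    text: str


@dataclass
class Doc:
    id: str
    title: str | None
    sections: list[Section] = field(default_factory=list)


def drop_empty_documents(documents: list[Doc]) -> list[Doc]:
    kept: list[Doc] = []
    for document in documents:
        if not document.title and not any(
            section.text.strip() for section in document.sections
        ):
            logger.warning("Skipping document %s: no title and no content", document.id)
            continue
        kept.append(document)
    return kept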
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
logger = setup_logger()


Embedding = list[float]
Embedding = list[float] | None


class ChunkEmbedding(BaseModel):
@@ -36,15 +36,13 @@ class DocAwareChunk(BaseChunk):
|
||||
# During inference we only have access to the document id and do not reconstruct the Document
|
||||
source_document: Document
|
||||
|
||||
# The Vespa documents require a separate highlight field. Since it is stored as a duplicate anyway,
|
||||
# it's easier to just store a not prefixed/suffixed string for the highlighting
|
||||
# Also during the chunking, this non-prefixed/suffixed string is used for mini-chunks
|
||||
content_summary: str
|
||||
title_prefix: str
|
||||
|
||||
# During indexing we also (optionally) build a metadata string from the metadata dict
|
||||
# This is also indexed so that we can strip it out after indexing, this way it supports
|
||||
# multiple iterations of metadata representation for backwards compatibility
|
||||
metadata_suffix: str
|
||||
metadata_suffix_semantic: str
|
||||
metadata_suffix_keyword: str
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
"""Used when logging the identity of a chunk"""
|
||||
|
||||
@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import message_generator_to_string_generator
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.tools.custom.custom_tool_prompt_builder import (
|
||||
build_user_message_for_custom_tool_for_non_tool_calling_llm,
|
||||
)
|
||||
@@ -99,6 +99,7 @@ class Answer:
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
llm: LLM,
|
||||
prompt_config: PromptConfig,
|
||||
force_use_tool: ForceUseTool,
|
||||
# must be the same length as `docs`. If None, all docs are considered "relevant"
|
||||
message_history: list[PreviousMessage] | None = None,
|
||||
single_message_history: str | None = None,
|
||||
@@ -107,10 +108,8 @@ class Answer:
|
||||
latest_query_files: list[InMemoryChatFile] | None = None,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
# if specified, tells the LLM to always this tool
|
||||
# NOTE: for native tool-calling, this is only supported by OpenAI atm,
|
||||
# but we only support them anyways
|
||||
force_use_tool: ForceUseTool | None = None,
|
||||
# if set to True, then never use the LLMs provided tool-calling functonality
|
||||
skip_explicit_tool_calling: bool = False,
|
||||
# Returns the full document sections text from the search tool
|
||||
@@ -129,6 +128,7 @@ class Answer:
|
||||
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
|
||||
self.skip_explicit_tool_calling = skip_explicit_tool_calling
|
||||
|
||||
self.message_history = message_history or []
|
||||
@@ -187,7 +187,7 @@ class Answer:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool and self.force_use_tool.args is not None:
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
@@ -221,7 +221,7 @@ class Answer:
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
@@ -240,12 +240,26 @@ class Answer:
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
tool = [
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
][0]
]

if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool and self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)

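Note: the lookup above no longer assumes the requested tool exists. It searches by name, logs unknown names, and falls back to the first available tool (or skips the call entirely). A standalone sketch of that pattern, with FakeTool standing in for the real Tool class:

# Sketch of the guarded tool lookup: prefer the requested tool by name,
# log and fall back to the first known tool, or skip if none exist.
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class FakeTool:  # stand-in for the real Tool class
    name: str


def pick_tool(tools: list[FakeTool], requested_name: str) -> FakeTool | None:
    matches = [tool for tool in tools if tool.name == requested_name]
    if matches:
        return matches[0]
    logger.error("Tool call requested with unknown name %r; tools=%s", requested_name, tools)
    # fall back to the first available tool rather than raising
    return tools[0] if tools else None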
@@ -286,7 +300,7 @@ class Answer:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
if self.force_use_tool:
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
@@ -303,7 +317,7 @@ class Answer:
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
|
||||
@@ -16,6 +16,7 @@ from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import ChatMessage
|
||||
@@ -32,6 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@@ -49,6 +51,14 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
||||
@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import check_message_tokens
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import drop_messages_history_overflow
|
||||
|
||||
@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import tokenizer_trim_content
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceSection
|
||||
|
||||
@@ -266,7 +266,9 @@ class DefaultMultiLLM(LLM):
stream=stream,
# model params
temperature=self._temperature,
max_tokens=self._max_output_tokens,
max_tokens=self._max_output_tokens
if self._max_output_tokens > 0
else None,
timeout=self._timeout,
**self._model_kwargs,
)

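Note: the max_tokens change treats a non-positive configured limit as "not set" so the underlying client falls back to the provider default. A trivial sketch of the mapping:

# Sketch: map a non-positive configured output limit to None so the downstream
# client (litellm here) falls back to the provider's own default.
def resolve_max_tokens(configured_max_output_tokens: int) -> int | None:
    return configured_max_output_tokens if configured_max_output_tokens > 0 else None


assert resolve_max_tokens(1024) == 1024
assert resolve_max_tokens(0) is None  # 0 (or negative) means "use provider default"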
@@ -70,6 +70,8 @@ def load_llm_providers(db_session: Session) -> None:
|
||||
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
|
||||
),
|
||||
model_names=model_names,
|
||||
is_public=True,
|
||||
display_model_names=[],
|
||||
)
|
||||
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
|
||||
update_default_provider(db_session, llm_provider.id)
|
||||
|
||||
@@ -31,9 +31,7 @@ OPEN_AI_MODEL_NAMES = [
"gpt-4-turbo-preview",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
# "gpt-4-32k", # not EOL but still doesnt work
"gpt-4-0613",
# "gpt-4-32k-0613", # not EOL but still doesnt work
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-3.5-turbo",
@@ -48,9 +46,11 @@ OPEN_AI_MODEL_NAMES = [
BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
# models
BEDROCK_MODEL_NAMES = [model for model in litellm.bedrock_models if "/" not in model][
::-1
]
BEDROCK_MODEL_NAMES = [
model
for model in litellm.bedrock_models
if "/" not in model and "embed" not in model
][::-1]

IGNORABLE_ANTHROPIC_MODELS = [
|
||||
"claude-2",
|
||||
@@ -84,7 +84,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
custom_config_keys=[],
|
||||
llm_names=fetch_models_for_provider(OPENAI_PROVIDER_NAME),
|
||||
default_model="gpt-4",
|
||||
default_fast_model="gpt-3.5-turbo",
|
||||
default_fast_model="gpt-4o-mini",
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=ANTHROPIC_PROVIDER_NAME,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -16,10 +15,8 @@ from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from tiktoken.core import Encoding
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
|
||||
@@ -28,7 +25,6 @@ from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
|
||||
@@ -37,60 +33,17 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_LLM_TOKENIZER: Any = None
|
||||
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
|
||||
|
||||
|
||||
def get_default_llm_tokenizer() -> Encoding:
|
||||
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
|
||||
global _LLM_TOKENIZER
|
||||
if _LLM_TOKENIZER is None:
|
||||
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
|
||||
return _LLM_TOKENIZER
|
||||
|
||||
|
||||
def get_default_llm_token_encode() -> Callable[[str], Any]:
|
||||
global _LLM_TOKENIZER_ENCODE
|
||||
if _LLM_TOKENIZER_ENCODE is None:
|
||||
tokenizer = get_default_llm_tokenizer()
|
||||
if isinstance(tokenizer, Encoding):
|
||||
return tokenizer.encode # type: ignore
|
||||
|
||||
# Currently only supports OpenAI encoder
|
||||
raise ValueError("Invalid Encoder selected")
|
||||
|
||||
return _LLM_TOKENIZER_ENCODE
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: Encoding
|
||||
) -> str:
|
||||
tokens = tokenizer.encode(content)
|
||||
if len(tokens) > desired_length:
|
||||
content = tokenizer.decode(tokens[:desired_length])
|
||||
return content
|
||||
|
||||
|
||||
def tokenizer_trim_chunks(
|
||||
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
) -> list[InferenceChunk]:
|
||||
tokenizer = get_default_llm_tokenizer()
|
||||
new_chunks = copy(chunks)
|
||||
for ind, chunk in enumerate(new_chunks):
|
||||
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
|
||||
if len(new_content) != len(chunk.content):
|
||||
new_chunk = copy(chunk)
|
||||
new_chunk.content = new_content
|
||||
new_chunks[ind] = new_chunk
|
||||
return new_chunks
|
||||
|
||||
|
||||
def translate_danswer_msg_to_langchain(
|
||||
msg: Union[ChatMessage, "PreviousMessage"],
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
# If the message is a `ChatMessage`, it doesn't have the downloaded files
|
||||
# attached. Just ignore them for now
|
||||
files = [] if isinstance(msg, ChatMessage) else msg.files
|
||||
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
|
||||
# be attached to AI messages, so we must remove them
|
||||
if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(msg.message, files)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
|
||||
@@ -50,8 +50,8 @@ from danswer.db.standard_answer import create_initial_default_standard_answer_ca
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.llm.llm_initialization import load_llm_providers
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.server.auth_check import check_router_auth
|
||||
from danswer.server.danswer_api.ingestion import router as danswer_api_router
|
||||
from danswer.server.documents.cc_pair import router as cc_pair_router
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import requests
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
from httpx import HTTPError
|
||||
|
||||
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.natural_language_processing.utils import get_default_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.utils.batching import batch_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
@@ -20,50 +19,13 @@ from shared_configs.model_server_models import IntentResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
transformer_logging.set_verbosity_error()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
|
||||
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
|
||||
|
||||
|
||||
def clean_model_name(model_str: str) -> str:
|
||||
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
|
||||
|
||||
|
||||
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
|
||||
# for cases where this is more important, be sure to refresh with the actual model name
|
||||
# One case where it is not particularly important is in the document chunking flow,
|
||||
# they're basically all using the sentencepiece tokenizer and whether it's cased or
|
||||
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
|
||||
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
|
||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
global _TOKENIZER
|
||||
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
|
||||
if _TOKENIZER[0] is not None:
|
||||
del _TOKENIZER
|
||||
gc.collect()
|
||||
|
||||
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
|
||||
|
||||
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
return _TOKENIZER[0]
|
||||
|
||||
|
||||
def build_model_server_url(
|
||||
model_server_host: str,
|
||||
model_server_port: int,
|
||||
@@ -91,6 +53,7 @@ class EmbeddingModel:
|
||||
provider_type: str | None,
|
||||
# The following are globals are currently not configurable
|
||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
retrim_content: bool = False,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.provider_type = provider_type
|
||||
@@ -99,32 +62,90 @@ class EmbeddingModel:
|
||||
self.passage_prefix = passage_prefix
|
||||
self.normalize = normalize
|
||||
self.model_name = model_name
|
||||
self.retrim_content = retrim_content
|
||||
|
||||
model_server_url = build_model_server_url(server_host, server_port)
|
||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
|
||||
def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
|
||||
if text_type == EmbedTextType.QUERY and self.query_prefix:
|
||||
prefixed_texts = [self.query_prefix + text for text in texts]
|
||||
elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
|
||||
prefixed_texts = [self.passage_prefix + text for text in texts]
|
||||
else:
|
||||
prefixed_texts = texts
|
||||
def encode(
|
||||
self,
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
|
||||
) -> list[list[float] | None]:
|
||||
if not texts:
|
||||
logger.warning("No texts to be embedded")
|
||||
return []
|
||||
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=prefixed_texts,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
)
|
||||
if self.retrim_content:
|
||||
# This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
|
||||
# Note that this uses just the default tokenizer which may also lead to very minor miscountings
|
||||
# However this slight miscounting is very unlikely to have any material impact.
|
||||
texts = [
|
||||
tokenizer_trim_content(
|
||||
content=text,
|
||||
desired_length=self.max_seq_length,
|
||||
tokenizer=get_default_tokenizer(),
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
|
||||
response.raise_for_status()
|
||||
if self.provider_type:
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=texts,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
manual_query_prefix=self.query_prefix,
|
||||
manual_passage_prefix=self.passage_prefix,
|
||||
)
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint, json=embed_request.dict()
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
||||
except requests.RequestException as e:
|
||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
||||
EmbedResponse(**response.json()).embeddings
|
||||
|
||||
return EmbedResponse(**response.json()).embeddings
|
||||
return EmbedResponse(**response.json()).embeddings
|
||||
|
||||
# Batching for local embedding
text_batches = batch_list(texts, batch_size)
embeddings: list[list[float] | None] = []
for idx, text_batch in enumerate(text_batches, start=1):
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)

response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# Normalize embeddings is only configured via model_configs.py, be sure to use right
# value for the set loss
embeddings.extend(EmbedResponse(**response.json()).embeddings)
return embeddings
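Note: the local-model branch batches the texts and, on failure, surfaces the model server's `detail` field instead of a bare status error. A trimmed, standalone sketch of that request loop; the endpoint URL and JSON shape are illustrative placeholders, not the exact EmbedRequest schema:

# Sketch of the batched embedding request loop with explicit error reporting.
# The endpoint URL and JSON payload shape are illustrative placeholders.
import requests
from httpx import HTTPError


def batch_list(items: list[str], batch_size: int) -> list[list[str]]:
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


def embed_locally(
    texts: list[str], endpoint: str, batch_size: int = 8
) -> list[list[float]]:
    embeddings: list[list[float]] = []
    for text_batch in batch_list(texts, batch_size):
        response = requests.post(endpoint, json={"texts": text_batch})
        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            # bubble up the model server's error detail instead of a bare 500
            detail = response.json().get("detail", str(e))
            raise HTTPError(f"HTTP error occurred: {detail}") from e
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {e}") from e
        embeddings.extend(response.json()["embeddings"])
    return embeddings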
|
||||
|
||||
class CrossEncoderEnsembleModel:
|
||||
@@ -136,7 +157,7 @@ class CrossEncoderEnsembleModel:
|
||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
|
||||
|
||||
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
|
||||
def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
|
||||
rerank_request = RerankRequest(query=query, documents=passages)
|
||||
|
||||
response = requests.post(
|
||||
@@ -199,7 +220,7 @@ def warm_up_encoders(
|
||||
# First time downloading the models it may take even longer, but just in case,
|
||||
# retry the whole server
|
||||
wait_time = 5
|
||||
for attempt in range(20):
|
||||
for _ in range(20):
|
||||
try:
|
||||
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||
return
|
||||
100
backend/danswer/natural_language_processing/utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import gc
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.core import Encoding
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
transformer_logging.set_verbosity_error()
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
||||
|
||||
|
||||
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
|
||||
_LLM_TOKENIZER: Any = None
|
||||
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
|
||||
|
||||
|
||||
# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
|
||||
# for cases where this is more important, be sure to refresh with the actual model name
|
||||
# One case where it is not particularly important is in the document chunking flow,
|
||||
# they're basically all using the sentencepiece tokenizer and whether it's cased or
|
||||
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
|
||||
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
|
||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
global _TOKENIZER
|
||||
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
|
||||
if _TOKENIZER[0] is not None:
|
||||
del _TOKENIZER
|
||||
gc.collect()
|
||||
|
||||
_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
|
||||
|
||||
if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
return _TOKENIZER[0]
|
||||
|
||||
|
||||
def get_default_llm_tokenizer() -> Encoding:
|
||||
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
|
||||
global _LLM_TOKENIZER
|
||||
if _LLM_TOKENIZER is None:
|
||||
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
|
||||
return _LLM_TOKENIZER
|
||||
|
||||
|
||||
def get_default_llm_token_encode() -> Callable[[str], Any]:
|
||||
global _LLM_TOKENIZER_ENCODE
|
||||
if _LLM_TOKENIZER_ENCODE is None:
|
||||
tokenizer = get_default_llm_tokenizer()
|
||||
if isinstance(tokenizer, Encoding):
|
||||
return tokenizer.encode # type: ignore
|
||||
|
||||
# Currently only supports OpenAI encoder
|
||||
raise ValueError("Invalid Encoder selected")
|
||||
|
||||
return _LLM_TOKENIZER_ENCODE
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: Encoding
|
||||
) -> str:
|
||||
tokens = tokenizer.encode(content)
|
||||
if len(tokens) > desired_length:
|
||||
content = tokenizer.decode(tokens[:desired_length])
|
||||
return content
|
||||
|
||||
|
||||
def tokenizer_trim_chunks(
|
||||
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
) -> list[InferenceChunk]:
|
||||
tokenizer = get_default_llm_tokenizer()
|
||||
new_chunks = copy(chunks)
|
||||
for ind, chunk in enumerate(new_chunks):
|
||||
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
|
||||
if len(new_content) != len(chunk.content):
|
||||
new_chunk = copy(chunk)
|
||||
new_chunk.content = new_content
|
||||
new_chunks[ind] = new_chunk
|
||||
return new_chunks
|
||||
@@ -34,7 +34,7 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import QuotesConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.natural_language_processing.utils import get_default_llm_token_encode
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
@@ -129,7 +129,7 @@ def stream_answer_objects(
messages=history, max_tokens=max_history_tokens
)

rephrased_query = thread_based_query_rephrase(
rephrased_query = query_req.query_override or thread_based_query_rephrase(
user_query=query_msg.message,
history_str=history_str,
)
@@ -206,6 +206,7 @@ def stream_answer_objects(
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=True,
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
),
|
||||
@@ -256,6 +257,9 @@ def stream_answer_objects(
|
||||
)
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
chunk_indices = packet.response
|
||||
|
||||
@@ -268,9 +272,6 @@ def stream_answer_objects(
|
||||
|
||||
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SEARCH_EVALUATION_ID:
|
||||
evaluation_response = LLMRelevanceSummaryResponse(
|
||||
relevance_summaries=packet.response
|
||||
|
||||
@@ -34,10 +34,17 @@ class DirectQARequest(ChunkContext):
skip_llm_chunk_filter: bool | None = None
chain_of_thought: bool = False
return_contexts: bool = False

# allows the caller to specify the exact search query they want to use
# can be used if the message sent to the LLM / query should not be the same
# will also disable Thread-based Rewording if specified
query_override: str | None = None

# This is to toggle agentic evaluation:
# 1. Evaluates whether each response is relevant or not
# 2. Provides a summary of the document's relevance in the results
llm_doc_eval: bool = False

# If True, skips generating an AI response to the search query
skip_gen_ai_answer_generation: bool = False

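Note: with query_override a caller can pin the exact retrieval query instead of relying on thread-based rephrasing. An illustrative fragment of a DirectQARequest-style payload; only fields visible in this hunk are shown and the values are made up:

# Illustrative fragment of a DirectQARequest payload; only the fields added or
# shown in this hunk appear here, and the string values are made up.
request_fragment = {
    "query_override": "credential swap endpoint",  # exact search query, skips rephrasing
    "llm_doc_eval": True,  # toggle the agentic relevance evaluation
    "chain_of_thought": False,
    "skip_gen_ai_answer_generation": False,  # set True for search-only calls
}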
@@ -2,7 +2,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.natural_language_processing.utils import get_default_llm_token_encode
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -154,6 +155,7 @@ class SearchPipeline:
|
||||
|
||||
return cast(list[InferenceChunk], self._retrieved_chunks)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_sections(self) -> list[InferenceSection]:
|
||||
"""Returns an expanded section from each of the chunks.
|
||||
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
|
||||
@@ -173,9 +175,11 @@ class SearchPipeline:
|
||||
expanded_inference_sections = []
|
||||
|
||||
# Full doc setting takes priority
|
||||
|
||||
if self.search_query.full_doc:
|
||||
seen_document_ids = set()
|
||||
unique_chunks = []
|
||||
|
||||
# This preserves the ordering since the chunks are retrieved in score order
|
||||
for chunk in retrieved_chunks:
|
||||
if chunk.document_id not in seen_document_ids:
|
||||
@@ -195,7 +199,6 @@ class SearchPipeline:
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
list_inference_chunks = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=False
|
||||
)
|
||||
@@ -240,32 +243,35 @@ class SearchPipeline:
|
||||
merged_ranges = [
|
||||
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
|
||||
]
|
||||
flat_ranges = [r for ranges in merged_ranges for r in ranges]
|
||||
|
||||
flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
|
||||
flattened_inference_chunks: list[InferenceChunk] = []
|
||||
parallel_functions_with_args = []
|
||||
|
||||
for chunk_range in flat_ranges:
|
||||
functions_with_args.append(
|
||||
(
|
||||
# If Large Chunks are introduced, additional filters need to be added here
|
||||
self.document_index.id_based_retrieval,
|
||||
(
|
||||
# Only need the document_id here, just use any chunk in the range is fine
|
||||
chunk_range.chunks[0].document_id,
|
||||
chunk_range.start,
|
||||
chunk_range.end,
|
||||
# There is no chunk level permissioning, this expansion around chunks
|
||||
# can be assumed to be safe
|
||||
IndexFilters(access_control_list=None),
|
||||
),
|
||||
)
|
||||
)
|
||||
# Don't need to fetch chunks within range for merging if chunk_above / below are 0.
|
||||
if above == below == 0:
|
||||
flattened_inference_chunks.extend(chunk_range.chunks)
|
||||
|
||||
# list of list of inference chunks where the inner list needs to be combined for content
|
||||
list_inference_chunks = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=False
|
||||
)
|
||||
flattened_inference_chunks = [
|
||||
chunk for sublist in list_inference_chunks for chunk in sublist
|
||||
]
|
||||
else:
|
||||
parallel_functions_with_args.append(
|
||||
(
|
||||
self.document_index.id_based_retrieval,
|
||||
(
|
||||
chunk_range.chunks[0].document_id,
|
||||
chunk_range.start,
|
||||
chunk_range.end,
|
||||
IndexFilters(access_control_list=None),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if parallel_functions_with_args:
|
||||
list_inference_chunks = run_functions_tuples_in_parallel(
|
||||
parallel_functions_with_args, allow_failures=False
|
||||
)
|
||||
for inference_chunks in list_inference_chunks:
|
||||
flattened_inference_chunks.extend(inference_chunks)
|
||||
|
||||
doc_chunk_ind_to_chunk = {
|
||||
(chunk.document_id, chunk.chunk_id): chunk
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import cast
|
||||
|
||||
import numpy
|
||||
|
||||
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
|
||||
from danswer.configs.app_configs import BLURB_SIZE
|
||||
from danswer.configs.constants import RETURN_SEPARATOR
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||
@@ -12,6 +12,9 @@ from danswer.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.natural_language_processing.search_nlp_models import (
|
||||
CrossEncoderEnsembleModel,
|
||||
)
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
@@ -20,7 +23,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
||||
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
@@ -58,8 +60,14 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
if chunk.content.startswith(chunk.title):
return chunk.content[len(chunk.title) :].lstrip()

if chunk.content.startswith(chunk.title[:MAX_CHUNK_TITLE_LEN]):
return chunk.content[MAX_CHUNK_TITLE_LEN:].lstrip()
# BLURB SIZE is by token instead of char but each token is at least 1 char
# If this prefix matches the content, it's assumed the title was prepended
if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
return (
chunk.content.split(RETURN_SEPARATOR, 1)[-1]
if RETURN_SEPARATOR in chunk.content
else chunk.content
)

return chunk.content

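Note: the cleanup above now strips a prepended title by checking a blurb-sized prefix and cutting at the return separator. A self-contained sketch of that heuristic; BLURB_SIZE and RETURN_SEPARATOR below are stand-ins for the real constants:

# Sketch of the title-stripping heuristic: if the chunk content starts with the
# title (or its blurb-sized prefix), drop everything up to the separator.
BLURB_SIZE = 128  # stand-in; the real value lives in app_configs
RETURN_SEPARATOR = "\n\r\n"  # stand-in for the real constant


def strip_prepended_title(content: str, title: str | None) -> str:
    if not title:
        return content
    if content.startswith(title):
        return content[len(title):].lstrip()
    # BLURB_SIZE is measured in tokens, but every token is at least one char,
    # so a char-based prefix check is a safe approximation
    if content.startswith(title[:BLURB_SIZE]):
        return (
            content.split(RETURN_SEPARATOR, 1)[-1]
            if RETURN_SEPARATOR in content
            else content
        )
    return content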
@@ -1,10 +1,10 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.natural_language_processing.search_nlp_models import get_default_tokenizer
|
||||
from danswer.natural_language_processing.search_nlp_models import IntentModel
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
|
||||
from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
from danswer.search.search_nlp_models import IntentModel
|
||||
from danswer.server.query_and_chat.models import HelperResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import nltk # type:ignore
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
@@ -11,6 +12,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunk
|
||||
@@ -20,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.utils import inference_section_from_chunks
|
||||
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -143,7 +144,9 @@ def doc_index_retrieval(
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
top_chunks = document_index.semantic_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
query_embedding=cast(
|
||||
list[float], query_embedding
|
||||
), # query embeddings should always have vector representations
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
@@ -152,7 +155,9 @@ def doc_index_retrieval(
|
||||
elif query.search_type == SearchType.HYBRID:
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
query_embedding=cast(
|
||||
list[float], query_embedding
|
||||
), # query embeddings should always have vector representations
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
|
||||
@@ -94,7 +94,7 @@ def history_based_query_rephrase(
llm: LLM,
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False,
skip_first_rephrase: bool = True,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
# Globally disabled, just use the exact user query

@@ -96,6 +96,8 @@ def upsert_ingestion_doc(
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@@ -132,6 +134,8 @@ def upsert_ingestion_doc(
|
||||
normalize=sec_db_embedding_model.normalize,
|
||||
query_prefix=sec_db_embedding_model.query_prefix,
|
||||
passage_prefix=sec_db_embedding_model.passage_prefix,
|
||||
api_key=sec_db_embedding_model.api_key,
|
||||
provider_type=sec_db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
sec_ind_pipeline = build_indexing_pipeline(
|
||||
|
||||
@@ -12,7 +12,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
|
||||
from danswer.db.connector_credential_pair import remove_credential_from_connector
|
||||
from danswer.db.document import get_document_cnts_for_cc_pairs
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from danswer.db.index_attempt import get_index_attempts_for_connector
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import CCPairFullInfo
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
@@ -43,9 +43,9 @@ def get_cc_pair_full_info(
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
index_attempts = get_index_attempts_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_identifier=cc_pair_identifier,
|
||||
index_attempts = get_index_attempts_for_connector(
|
||||
db_session,
|
||||
cc_pair.connector_id,
|
||||
)
|
||||
|
||||
document_count_info_list = list(
|
||||
|
||||
@@ -51,6 +51,8 @@ from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector import get_connector_credential_ids
|
||||
from danswer.db.connector import update_connector
|
||||
from danswer.db.connector_credential_pair import add_credential_to_connector
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.credentials import create_credential
|
||||
from danswer.db.credentials import delete_gmail_service_account_credentials
|
||||
@@ -64,6 +66,7 @@ from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from danswer.db.index_attempt import get_latest_finished_index_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_latest_index_attempts
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
@@ -74,6 +77,7 @@ from danswer.server.documents.models import ConnectorBase
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorIndexingStatus
|
||||
from danswer.server.documents.models import ConnectorSnapshot
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import CredentialSnapshot
|
||||
from danswer.server.documents.models import FileUploadResponse
|
||||
from danswer.server.documents.models import GDriveCallback
|
||||
@@ -263,7 +267,8 @@ def upsert_service_account_credential(
|
||||
`Credential` table."""
|
||||
try:
|
||||
credential_base = build_service_account_creds(
|
||||
delegated_user_email=service_account_credential_request.google_drive_delegated_user
|
||||
DocumentSource.GOOGLE_DRIVE,
|
||||
delegated_user_email=service_account_credential_request.google_drive_delegated_user,
|
||||
)
|
||||
except ConfigNotFoundError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -288,7 +293,8 @@ def upsert_gmail_service_account_credential(
|
||||
`Credential` table."""
|
||||
try:
|
||||
credential_base = build_service_account_creds(
|
||||
delegated_user_email=service_account_credential_request.gmail_delegated_user
|
||||
DocumentSource.GMAIL,
|
||||
delegated_user_email=service_account_credential_request.gmail_delegated_user,
|
||||
)
|
||||
except ConfigNotFoundError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -386,8 +392,12 @@ def get_connector_indexing_status(
|
||||
secondary_index=secondary_index,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
cc_pair_to_latest_index_attempt = {
|
||||
(index_attempt.connector_id, index_attempt.credential_id): index_attempt
|
||||
(
|
||||
index_attempt.connector_credential_pair.connector_id,
|
||||
index_attempt.connector_credential_pair.credential_id,
|
||||
): index_attempt
|
||||
for index_attempt in latest_index_attempts
|
||||
}
|
||||
|
||||
@@ -410,6 +420,11 @@ def get_connector_indexing_status(
|
||||
latest_index_attempt = cc_pair_to_latest_index_attempt.get(
|
||||
(connector.id, credential.id)
|
||||
)
|
||||
|
||||
latest_finished_attempt = get_latest_finished_index_attempt_for_cc_pair(
|
||||
connector_credential_pair_id=cc_pair.id, db_session=db_session
|
||||
)
|
||||
|
||||
indexing_statuses.append(
|
||||
ConnectorIndexingStatus(
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -418,6 +433,9 @@ def get_connector_indexing_status(
|
||||
credential=CredentialSnapshot.from_credential_db_model(credential),
|
||||
public_doc=cc_pair.is_public,
|
||||
owner=credential.user.email if credential.user else "",
|
||||
last_finished_status=latest_finished_attempt.status
|
||||
if latest_finished_attempt
|
||||
else None,
|
||||
last_status=latest_index_attempt.status
|
||||
if latest_index_attempt
|
||||
else None,
|
||||
@@ -480,6 +498,35 @@ def create_connector_from_model(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/admin/connector-with-mock-credential")
|
||||
def create_connector_with_mock_credential(
|
||||
connector_data: ConnectorBase,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
try:
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
connector_response = create_connector(connector_data, db_session)
|
||||
mock_credential = CredentialBase(
|
||||
credential_json={}, admin_public=True, source=connector_data.source
|
||||
)
|
||||
credential = create_credential(
|
||||
mock_credential, user=user, db_session=db_session
|
||||
)
|
||||
response = add_credential_to_connector(
|
||||
connector_id=cast(int, connector_response.id), # will aways be an int
|
||||
credential_id=credential.id,
|
||||
is_public=True,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
cc_pair_name=connector_data.name,
|
||||
)
|
||||
return response
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/admin/connector/{connector_id}")
|
||||
def update_connector_from_model(
|
||||
connector_id: int,
|
||||
@@ -515,6 +562,7 @@ def update_connector_from_model(
|
||||
credential_ids=[
|
||||
association.credential.id for association in updated_connector.credentials
|
||||
],
|
||||
indexing_start=updated_connector.indexing_start,
|
||||
time_created=updated_connector.time_created,
|
||||
time_updated=updated_connector.time_updated,
|
||||
disabled=updated_connector.disabled,
|
||||
@@ -542,6 +590,7 @@ def connector_run_once(
|
||||
) -> StatusResponse[list[int]]:
|
||||
connector_id = run_info.connector_id
|
||||
specified_credential_ids = run_info.credential_ids
|
||||
|
||||
try:
|
||||
possible_credential_ids = get_connector_credential_ids(
|
||||
run_info.connector_id, db_session
|
||||
@@ -585,16 +634,21 @@ def connector_run_once(

embedding_model = get_current_db_embedding_model(db_session)

connector_credential_pairs = [
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
for credential_id in credential_ids
if credential_id not in skipped_credentials
]

index_attempt_ids = [
create_index_attempt(
connector_id=run_info.connector_id,
credential_id=credential_id,
connector_credential_pair_id=connector_credential_pair.id,
embedding_model_id=embedding_model.id,
from_beginning=run_info.from_beginning,
db_session=db_session,
)
for credential_id in credential_ids
if credential_id not in skipped_credentials
for connector_credential_pair in connector_credential_pairs
if connector_credential_pair is not None
]

if not index_attempt_ids:
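Note: index attempts are now created from resolved connector/credential pairs, and credentials that no longer map to a pair are skipped. A small sketch of that flow with stand-in callables:

# Sketch of the new index-attempt creation: resolve each credential to its
# connector/credential pair first, then only create attempts for pairs that
# actually exist. `get_pair` / `create_attempt` are illustrative stand-ins.
from collections.abc import Callable


def build_attempt_ids(
    credential_ids: list[int],
    skipped: set[int],
    get_pair: Callable[[int], object | None],
    create_attempt: Callable[[object], int],
) -> list[int]:
    pairs = [
        get_pair(credential_id)
        for credential_id in credential_ids
        if credential_id not in skipped
    ]
    # silently drop credentials that no longer map to a cc_pair
    return [create_attempt(pair) for pair in pairs if pair is not None]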
@@ -724,6 +778,7 @@ def get_connector_by_id(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
source=connector.source,
|
||||
indexing_start=connector.indexing_start,
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
refresh_freq=connector.refresh_freq,
|
||||
|
||||
@@ -6,15 +6,21 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.credentials import alter_credential
|
||||
from danswer.db.credentials import create_credential
|
||||
from danswer.db.credentials import delete_credential
|
||||
from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.credentials import fetch_credentials
|
||||
from danswer.db.credentials import fetch_credentials_by_source
|
||||
from danswer.db.credentials import swap_credentials_connector
|
||||
from danswer.db.credentials import update_credential
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import DocumentSource
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import CredentialDataUpdateRequest
|
||||
from danswer.server.documents.models import CredentialSnapshot
|
||||
from danswer.server.documents.models import CredentialSwapRequest
|
||||
from danswer.server.documents.models import ObjectCreationIdResponse
|
||||
from danswer.server.models import StatusResponse
|
||||
|
||||
@@ -38,6 +44,34 @@ def list_credentials_admin(
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/similar-credentials/{source_type}")
|
||||
def get_cc_source_full_info(
    source_type: DocumentSource,
    user: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
    credentials = fetch_credentials_by_source(
        db_session=db_session, user=user, document_source=source_type
    )

    return [
        CredentialSnapshot.from_credential_db_model(credential)
        for credential in credentials
    ]


@router.get("/credentials/{id}")
def list_credentials_by_id(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
    credentials = fetch_credentials(db_session=db_session, user=user)
    return [
        CredentialSnapshot.from_credential_db_model(credential)
        for credential in credentials
    ]


@router.delete("/admin/credential/{credential_id}")
def delete_credential_by_id_admin(
    credential_id: int,
@@ -51,6 +85,26 @@ def delete_credential_by_id_admin(
    )


@router.put("/admin/credentials/swap")
def swap_credentials_for_connector(
    credentail_swap_req: CredentialSwapRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    connector_credential_pair = swap_credentials_connector(
        new_credential_id=credentail_swap_req.new_credential_id,
        connector_id=credentail_swap_req.connector_id,
        db_session=db_session,
        user=user,
    )

    return StatusResponse(
        success=True,
        message="Credential swapped successfully",
        data=connector_credential_pair.id,
    )


"""Endpoints for all"""


@@ -79,7 +133,11 @@ def create_credential_from_model(
    )

    credential = create_credential(credential_info, user, db_session)
    return ObjectCreationIdResponse(id=credential.id)

    return ObjectCreationIdResponse(
        id=credential.id,
        credential=CredentialSnapshot.from_credential_db_model(credential),
    )


@router.get("/credential/{credential_id}")
@@ -98,6 +156,24 @@ def get_credential_by_id(
    return CredentialSnapshot.from_credential_db_model(credential)


@router.put("/admin/credentials/{credential_id}")
def update_credential_data(
    credential_id: int,
    credential_update: CredentialDataUpdateRequest,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialBase:
    credential = alter_credential(credential_id, credential_update, user, db_session)

    if credential is None:
        raise HTTPException(
            status_code=401,
            detail=f"Credential {credential_id} does not exist or does not belong to user",
        )

    return CredentialSnapshot.from_credential_db_model(credential)


@router.patch("/credential/{credential_id}")
def update_credential_from_model(
    credential_id: int,
@@ -115,6 +191,7 @@ def update_credential_from_model(
    )

    return CredentialSnapshot(
        source=updated_credential.source,
        id=updated_credential.id,
        credential_json=updated_credential.credential_json,
        user_id=updated_credential.user_id,
@@ -130,7 +207,25 @@ def delete_credential_by_id(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    delete_credential(credential_id, user, db_session)
    delete_credential(
        credential_id,
        user,
        db_session,
    )

    return StatusResponse(
        success=True, message="Credential deleted successfully", data=credential_id
    )


@router.delete("/credential/force/{credential_id}")
def force_delete_credential_by_id(
    credential_id: int,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    delete_credential(credential_id, user, db_session, True)

    return StatusResponse(
        success=True, message="Credential deleted successfully", data=credential_id
    )
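Editor's note: a minimal client-side sketch of how the new credential-swap endpoint above could be called. The base URL and the `/manage` router prefix are assumptions (the prefix is not visible in this hunk); the payload fields mirror the `CredentialSwapRequest` model shown further down in this diff.

```python
# Sketch only: call PUT /admin/credentials/swap with a CredentialSwapRequest body.
import requests

API_BASE = "http://localhost:8080/manage"  # assumed host and router prefix


def swap_credential(connector_id: int, new_credential_id: int) -> int:
    """Swap the credential attached to a connector; returns the cc-pair id."""
    resp = requests.put(
        f"{API_BASE}/admin/credentials/swap",
        json={
            "connector_id": connector_id,
            "new_credential_id": new_credential_id,
        },
        timeout=30,
    )
    resp.raise_for_status()
    # StatusResponse.data carries the connector_credential_pair id
    return resp.json()["data"]
```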
@@ -9,7 +9,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_default_llm_token_encode
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo
@@ -26,6 +26,92 @@ class ChunkInfo(BaseModel):
    num_tokens: int


class DeletionAttemptSnapshot(BaseModel):
    connector_id: int
    credential_id: int
    status: TaskStatus


class ConnectorBase(BaseModel):
    name: str
    source: DocumentSource
    input_type: InputType
    connector_specific_config: dict[str, Any]
    refresh_freq: int | None  # In seconds, None for one time index with no refresh
    prune_freq: int | None
    disabled: bool
    indexing_start: datetime | None


class ConnectorSnapshot(ConnectorBase):
    id: int
    credential_ids: list[int]
    time_created: datetime
    time_updated: datetime
    source: DocumentSource

    @classmethod
    def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
        return ConnectorSnapshot(
            id=connector.id,
            name=connector.name,
            source=connector.source,
            input_type=connector.input_type,
            connector_specific_config=connector.connector_specific_config,
            refresh_freq=connector.refresh_freq,
            prune_freq=connector.prune_freq,
            credential_ids=[
                association.credential.id for association in connector.credentials
            ],
            indexing_start=connector.indexing_start,
            time_created=connector.time_created,
            time_updated=connector.time_updated,
            disabled=connector.disabled,
        )


class CredentialSwapRequest(BaseModel):
    new_credential_id: int
    connector_id: int


class CredentialDataUpdateRequest(BaseModel):
    name: str
    credential_json: dict[str, Any]


class CredentialBase(BaseModel):
    credential_json: dict[str, Any]
    # if `true`, then all Admins will have access to the credential
    admin_public: bool
    source: DocumentSource
    name: str | None = None


class CredentialSnapshot(CredentialBase):
    id: int
    user_id: UUID | None
    time_created: datetime
    time_updated: datetime

    @classmethod
    def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
        return CredentialSnapshot(
            id=credential.id,
            credential_json=(
                mask_credential_dict(credential.credential_json)
                if MASK_CREDENTIAL_PREFIX
                else credential.credential_json
            ),
            user_id=credential.user_id,
            admin_public=credential.admin_public,
            time_created=credential.time_created,
            time_updated=credential.time_updated,
            source=credential.source or DocumentSource.NOT_APPLICABLE,
            name=credential.name,
        )
class IndexAttemptSnapshot(BaseModel):
|
||||
id: int
|
||||
status: IndexingStatus | None
|
||||
@@ -49,80 +135,15 @@ class IndexAttemptSnapshot(BaseModel):
|
||||
docs_removed_from_index=index_attempt.docs_removed_from_index or 0,
|
||||
error_msg=index_attempt.error_msg,
|
||||
full_exception_trace=index_attempt.full_exception_trace,
|
||||
time_started=index_attempt.time_started.isoformat()
|
||||
if index_attempt.time_started
|
||||
else None,
|
||||
time_started=(
|
||||
index_attempt.time_started.isoformat()
|
||||
if index_attempt.time_started
|
||||
else None
|
||||
),
|
||||
time_updated=index_attempt.time_updated.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
class DeletionAttemptSnapshot(BaseModel):
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
status: TaskStatus
|
||||
|
||||
|
||||
class ConnectorBase(BaseModel):
|
||||
name: str
|
||||
source: DocumentSource
|
||||
input_type: InputType
|
||||
connector_specific_config: dict[str, Any]
|
||||
refresh_freq: int | None # In seconds, None for one time index with no refresh
|
||||
prune_freq: int | None
|
||||
disabled: bool
|
||||
|
||||
|
||||
class ConnectorSnapshot(ConnectorBase):
|
||||
id: int
|
||||
credential_ids: list[int]
|
||||
time_created: datetime
|
||||
time_updated: datetime
|
||||
|
||||
@classmethod
|
||||
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
|
||||
return ConnectorSnapshot(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
source=connector.source,
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
refresh_freq=connector.refresh_freq,
|
||||
prune_freq=connector.prune_freq,
|
||||
credential_ids=[
|
||||
association.credential.id for association in connector.credentials
|
||||
],
|
||||
time_created=connector.time_created,
|
||||
time_updated=connector.time_updated,
|
||||
disabled=connector.disabled,
|
||||
)
|
||||
|
||||
|
||||
class CredentialBase(BaseModel):
|
||||
credential_json: dict[str, Any]
|
||||
# if `true`, then all Admins will have access to the credential
|
||||
admin_public: bool
|
||||
|
||||
|
||||
class CredentialSnapshot(CredentialBase):
|
||||
id: int
|
||||
user_id: UUID | None
|
||||
time_created: datetime
|
||||
time_updated: datetime
|
||||
|
||||
@classmethod
|
||||
def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
|
||||
return CredentialSnapshot(
|
||||
id=credential.id,
|
||||
credential_json=mask_credential_dict(credential.credential_json)
|
||||
if MASK_CREDENTIAL_PREFIX
|
||||
else credential.credential_json,
|
||||
user_id=credential.user_id,
|
||||
admin_public=credential.admin_public,
|
||||
time_created=credential.time_created,
|
||||
time_updated=credential.time_updated,
|
||||
)
|
||||
|
||||
|
||||
class CCPairFullInfo(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
@@ -167,6 +188,7 @@ class ConnectorIndexingStatus(BaseModel):
|
||||
credential: CredentialSnapshot
|
||||
owner: str
|
||||
public_doc: bool
|
||||
last_finished_status: IndexingStatus | None
|
||||
last_status: IndexingStatus | None
|
||||
last_success: datetime | None
|
||||
docs_indexed: int
|
||||
@@ -242,6 +264,7 @@ class FileUploadResponse(BaseModel):
|
||||
|
||||
class ObjectCreationIdResponse(BaseModel):
|
||||
id: int | str
|
||||
credential: CredentialSnapshot | None = None
|
||||
|
||||
|
||||
class AuthStatus(BaseModel):
|
||||
|
||||
@@ -1,12 +1,15 @@
import uuid
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import UploadFile
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.persona import create_update_persona
@@ -17,6 +20,8 @@ from danswer.db.persona import mark_persona_as_not_deleted
from danswer.db.persona import update_all_personas_display_priority
from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
@@ -24,6 +29,7 @@ from danswer.server.features.persona.models import PromptTemplateResponse
from danswer.server.models import DisplayPriorityRequest
from danswer.utils.logger import setup_logger


logger = setup_logger()


@@ -90,6 +96,26 @@ def undelete_persona(
    )


# used for assistant profile pictures
@admin_router.post("/upload-image")
def upload_file(
    file: UploadFile,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> dict[str, str]:
    file_store = get_default_file_store(db_session)
    file_type = ChatFileType.IMAGE
    file_id = str(uuid.uuid4())
    file_store.save_file(
        file_name=file_id,
        content=file.file,
        display_name=file.filename,
        file_origin=FileOrigin.CHAT_UPLOAD,
        file_type=file.content_type or file_type.value,
    )
    return {"file_id": file_id}


"""Endpoints for all"""
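Editor's note: a rough sketch of calling the new assistant image upload endpoint above. The base URL and the `/admin/persona` prefix are assumptions (the `admin_router` prefix is not shown in this hunk); the multipart field name `file` and the `file_id` response key come from the handler itself.

```python
# Sketch only: upload an image for an assistant profile picture.
import requests

API_BASE = "http://localhost:8080"  # assumed host


def upload_assistant_image(path: str) -> str:
    with open(path, "rb") as f:
        resp = requests.post(
            f"{API_BASE}/admin/persona/upload-image",  # assumed prefix
            files={"file": (path, f, "image/png")},
            timeout=30,
        )
    resp.raise_for_status()
    return resp.json()["file_id"]
```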
@@ -33,6 +33,9 @@ class CreatePersonaRequest(BaseModel):
|
||||
# For Private Personas, who should be able to access these
|
||||
users: list[UUID] | None = None
|
||||
groups: list[int] | None = None
|
||||
icon_color: str | None = None
|
||||
icon_shape: int | None = None
|
||||
uploaded_image_id: str | None = None # New field for uploaded image
|
||||
|
||||
|
||||
class PersonaSnapshot(BaseModel):
|
||||
@@ -55,6 +58,9 @@ class PersonaSnapshot(BaseModel):
|
||||
document_sets: list[DocumentSet]
|
||||
users: list[MinimalUserSnapshot]
|
||||
groups: list[int]
|
||||
icon_color: str | None
|
||||
icon_shape: int | None
|
||||
uploaded_image_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -97,6 +103,9 @@ class PersonaSnapshot(BaseModel):
|
||||
for user in persona.users
|
||||
],
|
||||
groups=[user_group.id for user_group in persona.groups],
|
||||
icon_color=persona.icon_color,
|
||||
icon_shape=persona.icon_shape,
|
||||
uploaded_image_id=persona.uploaded_image_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from danswer.db.llm import fetch_existing_embedding_providers
|
||||
from danswer.db.llm import remove_embedding_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.models import User
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.embedding.models import TestEmbeddingRequest
|
||||
@@ -42,7 +42,7 @@ def test_embedding_configuration(
|
||||
passage_prefix=None,
|
||||
model_name=None,
|
||||
)
|
||||
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
|
||||
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = f"Not a valid embedding model. Exception thrown: {e}"
|
||||
|
||||
@@ -147,10 +147,10 @@ def set_provider_as_default(
|
||||
|
||||
@basic_router.get("/provider")
|
||||
def list_llm_provider_basics(
|
||||
_: User | None = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
return [
|
||||
LLMProviderDescriptor.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
|
||||
]
|
||||
|
||||
@@ -32,6 +32,7 @@ class LLMProviderDescriptor(BaseModel):
|
||||
default_model_name: str
|
||||
fast_default_model_name: str | None
|
||||
is_default_provider: bool | None
|
||||
display_model_names: list[str] | None
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -48,6 +49,7 @@ class LLMProviderDescriptor(BaseModel):
|
||||
or fetch_models_for_provider(llm_provider_model.provider)
|
||||
or [llm_provider_model.default_model_name]
|
||||
),
|
||||
display_model_names=llm_provider_model.display_model_names,
|
||||
)
|
||||
|
||||
|
||||
@@ -60,6 +62,9 @@ class LLMProvider(BaseModel):
|
||||
custom_config: dict[str, str] | None
|
||||
default_model_name: str
|
||||
fast_default_model_name: str | None
|
||||
is_public: bool = True
|
||||
groups: list[int] | None = None
|
||||
display_model_names: list[str] | None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
@@ -86,9 +91,12 @@ class FullLLMProvider(LLMProvider):
|
||||
default_model_name=llm_provider_model.default_model_name,
|
||||
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
||||
is_default_provider=llm_provider_model.is_default_provider,
|
||||
display_model_names=llm_provider_model.display_model_names,
|
||||
model_names=(
|
||||
llm_provider_model.model_names
|
||||
or fetch_models_for_provider(llm_provider_model.provider)
|
||||
or [llm_provider_model.default_model_name]
|
||||
),
|
||||
is_public=llm_provider_model.is_public,
|
||||
groups=[group.id for group in llm_provider_model.groups],
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -14,13 +15,15 @@ from danswer.db.models import SlackBotConfig as SlackBotConfigModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
from danswer.db.models import User
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import User as UserModel
|
||||
pass
|
||||
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
@@ -46,9 +49,17 @@ class UserInfo(BaseModel):
|
||||
is_verified: bool
|
||||
role: UserRole
|
||||
preferences: UserPreferences
|
||||
oidc_expiry: datetime | None = None
|
||||
current_token_created_at: datetime | None = None
|
||||
current_token_expiry_length: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user: "UserModel") -> "UserInfo":
|
||||
def from_model(
|
||||
cls,
|
||||
user: User,
|
||||
current_token_created_at: datetime | None = None,
|
||||
expiry_length: int | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
@@ -57,6 +68,9 @@ class UserInfo(BaseModel):
|
||||
is_verified=user.is_verified,
|
||||
role=user.role,
|
||||
preferences=(UserPreferences(chosen_assistants=user.chosen_assistants)),
|
||||
oidc_expiry=user.oidc_expiry,
|
||||
current_token_created_at=current_token_created_at,
|
||||
current_token_expiry_length=expiry_length,
|
||||
)
|
||||
|
||||
|
||||
@@ -151,7 +165,9 @@ class SlackBotConfigCreationRequest(BaseModel):
|
||||
# by an optional `PersonaSnapshot` object. Keeping it like this
|
||||
# for now for simplicity / speed of development
|
||||
document_sets: list[int] | None
|
||||
persona_id: int | None # NOTE: only one of `document_sets` / `persona_id` should be set
|
||||
persona_id: (
|
||||
int | None
|
||||
) # NOTE: only one of `document_sets` / `persona_id` should be set
|
||||
channel_names: list[str]
|
||||
respond_tag_only: bool = False
|
||||
respond_to_bots: bool = False
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Body
|
||||
@@ -6,6 +7,9 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -19,9 +23,11 @@ from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.db.users import list_users
|
||||
@@ -117,9 +123,9 @@ def list_all_users(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
status=UserStatus.LIVE
|
||||
if user.is_active
|
||||
else UserStatus.DEACTIVATED,
|
||||
status=(
|
||||
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
|
||||
),
|
||||
)
|
||||
for user in users
|
||||
],
|
||||
@@ -246,9 +252,35 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
|
||||
return UserRoleResponse(role=user.role)
|
||||
|
||||
|
||||
def get_current_token_creation(
|
||||
user: User | None, db_session: Session
|
||||
) -> datetime | None:
|
||||
if user is None:
|
||||
return None
|
||||
try:
|
||||
result = db_session.execute(
|
||||
select(AccessToken)
|
||||
.where(AccessToken.user_id == user.id) # type: ignore
|
||||
.order_by(desc(Column("created_at")))
|
||||
.limit(1)
|
||||
)
|
||||
access_token = result.scalar_one_or_none()
|
||||
|
||||
if access_token:
|
||||
return access_token.created_at
|
||||
else:
|
||||
logger.error("No AccessToken found for user")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching AccessToken: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
def verify_user_logged_in(
|
||||
user: User | None = Depends(optional_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserInfo:
|
||||
# NOTE: this does not use `current_user` / `current_admin_user` because we don't want
|
||||
# to enforce user verification here - the frontend always wants to get the info about
|
||||
@@ -264,7 +296,14 @@ def verify_user_logged_in(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User Not Authenticated"
|
||||
)
|
||||
|
||||
return UserInfo.from_model(user)
|
||||
token_created_at = get_current_token_creation(user, db_session)
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
return user_info
|
||||
|
||||
|
||||
"""APIs to adjust user preferences"""
|
||||
|
||||
@@ -45,7 +45,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.headers import get_litellm_additional_request_headers
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.secondary_llm_flows.chat_session_naming import (
|
||||
get_renamed_conversation_name,
|
||||
)
|
||||
|
||||
@@ -90,7 +90,7 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
parent_message_id: int | None
|
||||
# New message contents
|
||||
message: str
|
||||
# file's that we should attach to this message
|
||||
# Files that we should attach to this message
|
||||
file_descriptors: list[FileDescriptor]
|
||||
# If no prompt provided, uses the largest prompt of the chat session
|
||||
# but really this should be explicitly specified, only in the simplified APIs is this inferred
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
class ForceUseTool(BaseModel):
|
||||
# Could be not a forced usage of the tool but still have args, in which case
|
||||
# if the tool is called, then those args are applied instead of what the LLM
|
||||
# wanted to call it with
|
||||
force_use: bool
|
||||
tool_name: str
|
||||
args: dict[str, Any] | None = None
|
||||
|
||||
@@ -16,25 +18,10 @@ class ForceUseTool(BaseModel):
|
||||
return {"type": "function", "function": {"name": self.tool_name}}
|
||||
|
||||
|
||||
def modify_message_chain_for_force_use_tool(
|
||||
messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None
|
||||
) -> list[BaseMessage]:
|
||||
"""NOTE: modifies `messages` in place."""
|
||||
if not force_use_tool:
|
||||
return messages
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message, AIMessage) and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
tool_call["args"] = force_use_tool.args or {}
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def filter_tools_for_force_tool_use(
|
||||
tools: list[Tool], force_use_tool: ForceUseTool | None = None
|
||||
tools: list[Tool], force_use_tool: ForceUseTool
|
||||
) -> list[Tool]:
|
||||
if not force_use_tool:
|
||||
if not force_use_tool.force_use:
|
||||
return tools
|
||||
|
||||
return [tool for tool in tools if tool.name == force_use_tool.tool_name]
|
||||
|
||||
@@ -156,7 +156,9 @@ class ImageGenerationTool(Tool):
|
||||
for image_generation in image_generations
|
||||
]
|
||||
),
|
||||
img_urls=[image_generation.url for image_generation in image_generations],
|
||||
# NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow
|
||||
# Tool messages to contain images
|
||||
# img_urls=[image_generation.url for image_generation in image_generations],
|
||||
)
|
||||
|
||||
def _generate_image(self, prompt: str) -> ImageGenerationResponse:
|
||||
|
||||
@@ -10,7 +10,8 @@ Can you please summarize them in a sentence or two?
|
||||
"""
|
||||
|
||||
TOOL_CALLING_PROMPT = """
|
||||
Can you please summarize the two images you generate in a sentence or two?
|
||||
Can you please summarize the two images you just generated in a sentence or two? Do not use a
|
||||
numbered list.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
|
||||
|
||||
def build_tool_message(
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
|
||||
from tiktoken import Encoding
|
||||
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
|
||||
@@ -37,9 +37,24 @@ def fetch_versioned_implementation(module: str, attribute: str) -> Any:
    module_full = f"ee.{module}" if is_ee else module
    try:
        return getattr(importlib.import_module(module_full), attribute)
    except ModuleNotFoundError:
        # try the non-ee version as a fallback
    except ModuleNotFoundError as e:
        logger.warning(
            "Failed to fetch versioned implementation for %s.%s: %s",
            module_full,
            attribute,
            e,
        )

        if is_ee:
            if "ee.danswer" not in str(e):
                # If it's a non Danswer related import failure, this is likely because
                # a dependent library has not been installed. Should raise this failure
                # instead of letting the server start up
                raise e

            # Use the MIT version as a fallback, this allows us to develop MIT
            # versions independently and later add additional EE functionality
            # similar to feature flagging
            return getattr(importlib.import_module(module), attribute)

        raise
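Editor's note: a hypothetical usage sketch of the fallback behavior changed above — an `ee.`-prefixed override is tried first, and the MIT implementation is used when the EE module is missing. The import path, module, and attribute names below are illustrative assumptions, not taken from this diff.

```python
# Sketch only: resolve an EE override if installed, otherwise the MIT version.
from danswer.utils.versioning import fetch_versioned_implementation  # assumed path

list_user_groups_fn = fetch_versioned_implementation(
    "danswer.server.user_group.api",  # hypothetical module
    "list_user_groups",               # hypothetical attribute
)
user_groups = list_user_groups_fn()
```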
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
@@ -194,6 +195,15 @@ def _cleanup_user__user_group_relationships__no_commit(
|
||||
db_session.delete(user__user_group_relationship)
|
||||
|
||||
|
||||
def _cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.user_group_id == user_group_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def _mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session: Session, user_group_id: int
|
||||
) -> None:
|
||||
@@ -316,6 +326,9 @@ def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> Non
|
||||
|
||||
|
||||
def delete_user_group(db_session: Session, user_group: UserGroup) -> None:
|
||||
_cleanup_llm_provider__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
_cleanup_user__user_group_relationships__no_commit(
|
||||
db_session=db_session, user_group_id=user_group.id
|
||||
)
|
||||
|
||||
@@ -6,9 +6,9 @@ from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import danswer.db.models as db_models
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from ee.danswer.db.analytics import fetch_danswerbot_analytics
|
||||
from ee.danswer.db.analytics import fetch_per_user_query_analytics
|
||||
from ee.danswer.db.analytics import fetch_query_analytics
|
||||
@@ -27,7 +27,7 @@ class QueryAnalyticsResponse(BaseModel):
|
||||
def get_query_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[QueryAnalyticsResponse]:
|
||||
daily_query_usage_info = fetch_query_analytics(
|
||||
@@ -58,7 +58,7 @@ class UserAnalyticsResponse(BaseModel):
|
||||
def get_user_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserAnalyticsResponse]:
|
||||
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
|
||||
@@ -92,7 +92,7 @@ class DanswerbotAnalyticsResponse(BaseModel):
|
||||
def get_danswerbot_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[DanswerbotAnalyticsResponse]:
|
||||
daily_danswerbot_info = fetch_danswerbot_analytics(
|
||||
|
||||
@@ -2,9 +2,9 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import danswer.db.models as db_models
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from ee.danswer.db.api_key import ApiKeyDescriptor
|
||||
from ee.danswer.db.api_key import fetch_api_keys
|
||||
from ee.danswer.db.api_key import insert_api_key
|
||||
@@ -13,12 +13,13 @@ from ee.danswer.db.api_key import remove_api_key
|
||||
from ee.danswer.db.api_key import update_api_key
|
||||
from ee.danswer.server.api_key.models import APIKeyArgs
|
||||
|
||||
|
||||
router = APIRouter(prefix="/admin/api-key")
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_api_keys(
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[ApiKeyDescriptor]:
|
||||
return fetch_api_keys(db_session)
|
||||
@@ -27,7 +28,7 @@ def list_api_keys(
|
||||
@router.post("")
|
||||
def create_api_key(
|
||||
api_key_args: APIKeyArgs,
|
||||
user: db_models.User | None = Depends(current_admin_user),
|
||||
user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return insert_api_key(db_session, api_key_args, user.id if user else None)
|
||||
@@ -36,7 +37,7 @@ def create_api_key(
|
||||
@router.post("/{api_key_id}/regenerate")
|
||||
def regenerate_existing_api_key(
|
||||
api_key_id: int,
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return regenerate_api_key(db_session, api_key_id)
|
||||
@@ -46,7 +47,7 @@ def regenerate_existing_api_key(
|
||||
def update_existing_api_key(
|
||||
api_key_id: int,
|
||||
api_key_args: APIKeyArgs,
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ApiKeyDescriptor:
|
||||
return update_api_key(db_session, api_key_id, api_key_args)
|
||||
@@ -55,7 +56,7 @@ def update_existing_api_key(
|
||||
@router.delete("/{api_key_id}")
|
||||
def delete_api_key(
|
||||
api_key_id: int,
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
remove_api_key(db_session, api_key_id)
|
||||
|
||||
@@ -12,7 +12,6 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import danswer.db.models as db_models
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import get_display_email
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
@@ -22,9 +21,9 @@ from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import User
|
||||
from ee.danswer.db.query_history import fetch_chat_sessions_eagerly_by_time
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -303,7 +302,7 @@ def get_chat_session_history(
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[ChatSessionMinimal]:
|
||||
return fetch_and_process_chat_session_history_minimal(
|
||||
@@ -320,7 +319,7 @@ def get_chat_session_history(
|
||||
@router.get("/admin/chat-session-history/{chat_session_id}")
|
||||
def get_chat_session_admin(
|
||||
chat_session_id: int,
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
try:
|
||||
@@ -349,7 +348,7 @@ def get_chat_session_admin(
|
||||
|
||||
@router.get("/admin/query-history-csv")
|
||||
def get_query_history_as_csv(
|
||||
_: db_models.User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
|
||||
@@ -4,9 +4,9 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import danswer.db.models as db_models
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from ee.danswer.db.user_group import fetch_user_groups
|
||||
from ee.danswer.db.user_group import insert_user_group
|
||||
from ee.danswer.db.user_group import prepare_user_group_for_deletion
|
||||
@@ -20,7 +20,7 @@ router = APIRouter(prefix="/manage")
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
_: db_models.User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
user_groups = fetch_user_groups(db_session, only_current=False)
|
||||
@@ -30,7 +30,7 @@ def list_user_groups(
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
_: db_models.User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -48,7 +48,7 @@ def create_user_group(
|
||||
def patch_user_group(
|
||||
user_group_id: int,
|
||||
user_group: UserGroupUpdate,
|
||||
_: db_models.User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -62,7 +62,7 @@ def patch_user_group(
|
||||
@router.delete("/admin/user-group/{user_group_id}")
|
||||
def delete_user_group(
|
||||
user_group_id: int,
|
||||
_: db_models.User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
|
||||
@@ -10,6 +10,7 @@ from cohere import Client as CohereClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from retry import retry
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
@@ -40,110 +41,133 @@ router = APIRouter(prefix="/encoder")
|
||||
|
||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
||||
# If we are not only indexing, dont want retry very long
|
||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
||||
|
||||
|
||||
def _initialize_client(
|
||||
api_key: str, provider: EmbeddingProvider, model: str | None = None
|
||||
) -> Any:
|
||||
if provider == EmbeddingProvider.OPENAI:
|
||||
return openai.OpenAI(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.COHERE:
|
||||
return CohereClient(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.VOYAGE:
|
||||
return voyageai.Client(api_key=api_key)
|
||||
elif provider == EmbeddingProvider.GOOGLE:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(api_key)
|
||||
)
|
||||
project_id = json.loads(api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(self, api_key: str, provider: str, model: str | None = None):
|
||||
self.api_key = api_key
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
provider: str,
|
||||
# Only for Google as is needed on client setup
|
||||
self.model = model
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
self.provider = EmbeddingProvider(provider.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
self.client = self._initialize_client()
|
||||
self.client = _initialize_client(api_key, self.provider, model)
|
||||
|
||||
def _initialize_client(self) -> Any:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return openai.OpenAI(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.COHERE:
|
||||
return CohereClient(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return voyageai.Client(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.api_key)
|
||||
)
|
||||
project_id = json.loads(self.api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
return TextEmbeddingModel.from_pretrained(
|
||||
self.model or DEFAULT_VERTEX_MODEL
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
def encode(
|
||||
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
|
||||
) -> list[list[float]]:
|
||||
return [
|
||||
self.embed(text=text, text_type=text_type, model=model_name)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
def embed(
|
||||
self, *, text: str, text_type: EmbedTextType, model: str | None = None
|
||||
) -> list[float]:
|
||||
logger.debug(f"Embedding text with provider: {self.provider}")
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(text, model)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(text, model, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(text, model, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(text, model, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
def _embed_openai(self, text: str, model: str | None) -> list[float]:
|
||||
def _embed_openai(
|
||||
self, texts: list[str], model: str | None
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
|
||||
response = self.client.embeddings.create(input=text, model=model)
|
||||
return response.data[0].embedding
|
||||
# OpenAI does not seem to provide truncation option, however
|
||||
# the context lengths used by Danswer currently are smaller than the max token length
|
||||
# for OpenAI embeddings so it's not a big deal
|
||||
response = self.client.embeddings.create(input=texts, model=model)
|
||||
return [embedding.embedding for embedding in response.data]
|
||||
|
||||
def _embed_cohere(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_COHERE_MODEL
|
||||
|
||||
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
|
||||
# empirically it's only off by a very few tokens so it's not a big deal
|
||||
response = self.client.embed(
|
||||
texts=[text],
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncate="END",
|
||||
)
|
||||
return response.embeddings[0]
|
||||
return response.embeddings
|
||||
|
||||
def _embed_voyage(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_VOYAGE_MODEL
|
||||
|
||||
response = self.client.embed(text, model=model, input_type=embedding_type)
|
||||
return response.embeddings[0]
|
||||
# Similar to Cohere, the API server will do approximate size chunking
|
||||
# it's acceptable to miss by a few tokens
|
||||
response = self.client.embed(
|
||||
texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncation=True, # Also this is default
|
||||
)
|
||||
return response.embeddings
|
||||
|
||||
def _embed_vertex(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float] | None]:
|
||||
if model is None:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
|
||||
embedding = self.client.get_embeddings(
|
||||
embeddings = self.client.get_embeddings(
|
||||
[
|
||||
TextEmbeddingInput(
|
||||
text,
|
||||
embedding_type,
|
||||
)
|
||||
]
|
||||
for text in texts
|
||||
],
|
||||
auto_truncate=True, # Also this is default
|
||||
)
|
||||
return embedding[0].values
|
||||
return [embedding.values for embedding in embeddings]
|
||||
|
||||
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
|
||||
def embed(
|
||||
self,
|
||||
*,
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
) -> list[list[float] | None]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error embedding text with {self.provider}: {str(e)}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -212,34 +236,83 @@ def embed_text(
|
||||
normalize_embeddings: bool,
|
||||
api_key: str | None,
|
||||
provider_type: str | None,
|
||||
) -> list[list[float]]:
|
||||
if provider_type is not None:
|
||||
prefix: str | None,
|
||||
) -> list[list[float] | None]:
|
||||
non_empty_texts = []
|
||||
empty_indices = []
|
||||
|
||||
for idx, text in enumerate(texts):
|
||||
if text.strip():
|
||||
non_empty_texts.append(text)
|
||||
else:
|
||||
empty_indices.append(idx)
|
||||
|
||||
# Third party API based embedding model
|
||||
if not non_empty_texts:
|
||||
embeddings = []
|
||||
elif provider_type is not None:
|
||||
logger.debug(f"Embedding text with provider: {provider_type}")
|
||||
if api_key is None:
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
|
||||
if prefix:
|
||||
# This may change in the future if some providers require the user
|
||||
# to manually append a prefix but this is not the case currently
|
||||
raise ValueError(
|
||||
"Prefix string is not valid for cloud models. "
|
||||
"Cloud models take an explicit text type instead."
|
||||
)
|
||||
|
||||
cloud_model = CloudEmbedding(
|
||||
api_key=api_key, provider=provider_type, model=model_name
|
||||
)
|
||||
embeddings = cloud_model.encode(texts, model_name, text_type)
|
||||
embeddings = cloud_model.embed(
|
||||
texts=non_empty_texts,
|
||||
model_name=model_name,
|
||||
text_type=text_type,
|
||||
)
|
||||
|
||||
elif model_name is not None:
|
||||
hosted_model = get_embedding_model(
|
||||
prefixed_texts = (
|
||||
[f"{prefix}{text}" for text in non_empty_texts]
|
||||
if prefix
|
||||
else non_empty_texts
|
||||
)
|
||||
local_model = get_embedding_model(
|
||||
model_name=model_name, max_context_length=max_context_length
|
||||
)
|
||||
embeddings = hosted_model.encode(
|
||||
texts, normalize_embeddings=normalize_embeddings
|
||||
embeddings = local_model.encode(
|
||||
prefixed_texts, normalize_embeddings=normalize_embeddings
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
|
||||
if embeddings is None:
|
||||
raise RuntimeError("Embeddings were not created")
|
||||
raise RuntimeError("Failed to create Embeddings")
|
||||
|
||||
if not isinstance(embeddings, list):
|
||||
embeddings = embeddings.tolist()
|
||||
embeddings_with_nulls: list[list[float] | None] = []
|
||||
current_embedding_index = 0
|
||||
|
||||
for idx in range(len(texts)):
|
||||
if idx in empty_indices:
|
||||
embeddings_with_nulls.append(None)
|
||||
else:
|
||||
embedding = embeddings[current_embedding_index]
|
||||
if isinstance(embedding, list) or embedding is None:
|
||||
embeddings_with_nulls.append(embedding)
|
||||
else:
|
||||
embeddings_with_nulls.append(embedding.tolist())
|
||||
current_embedding_index += 1
|
||||
|
||||
embeddings = embeddings_with_nulls
|
||||
return embeddings
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
|
||||
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
|
||||
cross_encoders = get_local_reranking_model_ensemble()
|
||||
sim_scores = [
|
||||
encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
|
||||
@@ -252,7 +325,17 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest,
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
|
||||
try:
|
||||
if embed_request.text_type == EmbedTextType.QUERY:
|
||||
prefix = embed_request.manual_query_prefix
|
||||
elif embed_request.text_type == EmbedTextType.PASSAGE:
|
||||
prefix = embed_request.manual_passage_prefix
|
||||
else:
|
||||
prefix = None
|
||||
|
||||
embeddings = embed_text(
|
||||
texts=embed_request.texts,
|
||||
model_name=embed_request.model_name,
|
||||
@@ -261,13 +344,13 @@ async def process_embed_request(
|
||||
api_key=embed_request.api_key,
|
||||
provider_type=embed_request.provider_type,
|
||||
text_type=embed_request.text_type,
|
||||
prefix=prefix,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during embedding process:\n{str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to run Bi-Encoder embedding"
|
||||
)
|
||||
exception_detail = f"Error during embedding process:\n{str(e)}"
|
||||
logger.exception(exception_detail)
|
||||
raise HTTPException(status_code=500, detail=exception_detail)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
@@ -276,6 +359,11 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
if not embed_request.documents or not embed_request.query:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No documents or query to be reranked"
|
||||
)
|
||||
|
||||
try:
|
||||
sim_scores = calc_sim_scores(
|
||||
query=embed_request.query, docs=embed_request.documents
|
||||
|
||||
@@ -1,6 +1,7 @@
fastapi==0.109.2
h5py==3.9.0
pydantic==1.10.13
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
@@ -9,5 +10,5 @@ transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.5.8
google-cloud-aiplatform==1.58.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0
@@ -21,18 +21,17 @@ def run_jobs(exclude_indexing: bool) -> None:
|
||||
cmd_worker = [
|
||||
"celery",
|
||||
"-A",
|
||||
"ee.danswer.background.celery",
|
||||
"ee.danswer.background.celery.celery_app",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--autoscale=3,10",
|
||||
"--concurrency=16",
|
||||
"--loglevel=INFO",
|
||||
"--concurrency=1",
|
||||
]
|
||||
|
||||
cmd_beat = [
|
||||
"celery",
|
||||
"-A",
|
||||
"ee.danswer.background.celery",
|
||||
"ee.danswer.background.celery.celery_app",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
]
|
||||
@@ -74,7 +73,7 @@ def run_jobs(exclude_indexing: bool) -> None:
|
||||
try:
|
||||
update_env = os.environ.copy()
|
||||
update_env["PYTHONPATH"] = "."
|
||||
cmd_perm_sync = ["python", "ee.danswer/background/permission_sync.py"]
|
||||
cmd_perm_sync = ["python", "ee/danswer/background/permission_sync.py"]
|
||||
|
||||
indexing_process = subprocess.Popen(
|
||||
cmd_perm_sync,
|
||||
|
||||
@@ -14,7 +14,10 @@ sys.path.append(parent_dir)
|
||||
# flake8: noqa: E402
|
||||
|
||||
# Now import Danswer modules
|
||||
from danswer.db.models import DocumentSet__ConnectorCredentialPair
|
||||
from danswer.db.models import (
|
||||
DocumentSet__ConnectorCredentialPair,
|
||||
UserGroup__ConnectorCredentialPair,
|
||||
)
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.index_attempt import (
|
||||
@@ -44,7 +47,7 @@ logger = setup_logger()
|
||||
_DELETION_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def unsafe_deletion(
|
||||
def _unsafe_deletion(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
@@ -82,11 +85,22 @@ def unsafe_deletion(
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# Delete document sets + connector / credential Pairs
|
||||
# Delete document sets
|
||||
stmt = delete(DocumentSet__ConnectorCredentialPair).where(
|
||||
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id == pair_id
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
# delete user group associations
|
||||
stmt = delete(UserGroup__ConnectorCredentialPair).where(
|
||||
UserGroup__ConnectorCredentialPair.cc_pair_id == pair_id
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
# need to flush to avoid foreign key violations
|
||||
db_session.flush()
|
||||
|
||||
# delete the actual connector credential pair
|
||||
stmt = delete(ConnectorCredentialPair).where(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
@@ -168,7 +182,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
files_deleted_count = unsafe_deletion(
|
||||
files_deleted_count = _unsafe_deletion(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -4,9 +4,7 @@ from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
|
||||
# This already includes any prefixes, the text is just passed directly to the model
|
||||
texts: list[str]
|
||||
|
||||
# Can be none for cloud embedding model requests, error handling logic exists for other cases
|
||||
model_name: str | None
|
||||
max_context_length: int
|
||||
@@ -14,10 +12,12 @@ class EmbedRequest(BaseModel):
|
||||
api_key: str | None
|
||||
provider_type: str | None
|
||||
text_type: EmbedTextType
|
||||
manual_query_prefix: str | None
|
||||
manual_passage_prefix: str | None
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
|
||||
embeddings: list[list[float]]
|
||||
embeddings: list[list[float] | None]
|
||||
|
||||
|
||||
class RerankRequest(BaseModel):
|
||||
@@ -26,7 +26,7 @@ class RerankRequest(BaseModel):
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
|
||||
scores: list[list[float]]
|
||||
scores: list[list[float] | None]
|
||||
|
||||
|
||||
class IntentRequest(BaseModel):
|
||||
|
||||
@@ -25,7 +25,7 @@ autorestart=true
# relatively compute-light (e.g. they tend to just make a bunch of requests to
# Vespa / Postgres)
[program:celery_worker]
command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --autoscale=3,10 --loglevel=INFO --logfile=/var/log/celery_worker.log
command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads --concurrency=16 --loglevel=INFO --logfile=/var/log/celery_worker.log
stdout_logfile=/var/log/celery_worker_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
@@ -9,7 +9,7 @@ This Python script automates the process of running search quality tests for a b
|
||||
- Manages environment variables
|
||||
- Switches to specified Git branch
|
||||
- Uploads test documents
|
||||
- Runs search quality tests using Relari
|
||||
- Runs search quality tests
|
||||
- Cleans up Docker containers (optional)
|
||||
|
||||
## Usage
|
||||
@@ -29,9 +29,17 @@ export PYTHONPATH=$PYTHONPATH:$PWD/backend
|
||||
```
|
||||
cd backend/tests/regression/answer_quality
|
||||
```
|
||||
7. Run the script:
|
||||
7. To launch the evaluation environment, run the launch_eval_env.py script (this step can be skipped if you are running the env outside of docker, just leave "environment_name" blank):
|
||||
```
|
||||
python run_eval_pipeline.py
|
||||
python launch_eval_env.py
|
||||
```
|
||||
8. Run the file_uploader.py script to upload the zip files located at the path "zipped_documents_file"
|
||||
```
|
||||
python file_uploader.py
|
||||
```
|
||||
9. Run the run_qa.py script to ask questions from the jsonl located at the path "questions_file". This will hit the "query/answer-with-quote" API endpoint.
|
||||
```
|
||||
python run_qa.py
|
||||
```
|
||||
|
||||
Note: All data will be saved even after the containers are shut down. There are instructions below to re-launching docker containers using this data.
|
||||
@@ -61,6 +69,11 @@ Edit `search_test_config.yaml` to set:
|
||||
- Set this to true to automatically delete all docker containers, networks and volumes after the test
|
||||
- launch_web_ui
|
||||
- Set this to true if you want to use the UI during/after the testing process
|
||||
- only_state
|
||||
- Whether to only run Vespa and Postgres
|
||||
- only_retrieve_docs
|
||||
- Set true to only retrieve documents, not LLM response
|
||||
- This is to save on API costs
|
||||
- use_cloud_gpu
|
||||
- Set to true or false depending on if you want to use the remote gpu
|
||||
- Only need to set this if use_cloud_gpu is true
|
||||
@@ -70,12 +83,10 @@ Edit `search_test_config.yaml` to set:
 - model_server_port
   - This is the port of the remote model server
   - Only need to set this if use_cloud_gpu is true
-- existing_test_suffix (THIS IS NOT A SUFFIX ANYMORE, TODO UPDATE THE DOCS HERE)
+- environment_name
   - Use this if you would like to relaunch a previous test instance
-  - Input the suffix of the test you'd like to re-launch
-  - (E.g. to use the data from folder "test-1234-5678" put "-1234-5678")
-  - No new files will automatically be uploaded
-  - Leave empty to run a new test
+  - Input the env_name of the test you'd like to re-launch
+  - Leave empty to launch referencing local default network locations
 - limit
   - Max number of questions you'd like to ask against the dataset
   - Set to null for no limit
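
For illustration (not part of the diff): the test scripts load these settings from `search_test_config.yaml` into a `SimpleNamespace`, as the `__main__` blocks further down show. A minimal sketch of that pattern; the field values below are placeholders, not defaults from the repo:

```
import yaml
from types import SimpleNamespace

# Placeholder settings mirroring the option names documented above
config_yaml = """
environment_name: ""            # empty -> default local network locations
zipped_documents_file: ./zipped_documents.zip
questions_file: ./questions.jsonl
only_retrieve_docs: true        # skip LLM answer generation to save on API costs
launch_web_ui: false
use_cloud_gpu: false
only_state: false
limit: null                     # no cap on the number of questions asked
"""

config = SimpleNamespace(**yaml.safe_load(config_yaml))
print(config.environment_name, config.zipped_documents_file)
```
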
@@ -85,7 +96,7 @@ Edit `search_test_config.yaml` to set:

 ## Relaunching From Existing Data

-To launch an existing set of containers that has already completed indexing, set the existing_test_suffix variable. This will launch the docker containers mounted on the volumes of the indicated suffix and will not automatically index any documents or run any QA.
+To launch an existing set of containers that has already completed indexing, set the environment_name variable. This will launch the docker containers mounted on the volumes of the indicated env_name and will not automatically index any documents or run any QA.

 Once these containers are launched you can run file_uploader.py or run_qa.py (assuming you have run the steps in the Usage section above).
 - file_uploader.py will upload and index additional zipped files located at the zipped_documents_file path.
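
For illustration (not part of the diff): relaunching an existing environment programmatically goes through `start_docker_compose`, whose updated signature appears in the cli_utils changes below; the env_name value here is a made-up example:

```
from tests.regression.answer_quality.cli_utils import start_docker_compose

start_docker_compose(
    env_name="-1234-5678",  # hypothetical env_name whose volumes already exist
    launch_web_ui=True,     # keep the web UI available for manual inspection
    use_cloud_gpu=False,    # use the local model server
    only_state=False,       # bring up the full stack, not just Vespa/Postgres
)
```
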
@@ -2,45 +2,31 @@ import requests
 from retry import retry

 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import MessageType
 from danswer.connectors.models import InputType
 from danswer.db.enums import IndexingStatus
+from danswer.one_shot_answer.models import DirectQARequest
+from danswer.one_shot_answer.models import ThreadMessage
 from danswer.search.models import IndexFilters
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
 from danswer.server.documents.models import ConnectorBase
-from danswer.server.query_and_chat.models import ChatSessionCreationRequest
-from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
 from tests.regression.answer_quality.cli_utils import get_api_server_host_port

 GENERAL_HEADERS = {"Content-Type": "application/json"}


-def _api_url_builder(run_suffix: str, api_path: str) -> str:
-    return f"http://localhost:{get_api_server_host_port(run_suffix)}" + api_path
-
-
-def _create_new_chat_session(run_suffix: str) -> int:
-    create_chat_request = ChatSessionCreationRequest(
-        persona_id=0,
-        description=None,
-    )
-    body = create_chat_request.dict()
-
-    create_chat_url = _api_url_builder(run_suffix, "/chat/create-chat-session/")
-
-    response_json = requests.post(
-        create_chat_url, headers=GENERAL_HEADERS, json=body
-    ).json()
-    chat_session_id = response_json.get("chat_session_id")
-
-    if isinstance(chat_session_id, int):
-        return chat_session_id
+def _api_url_builder(env_name: str, api_path: str) -> str:
+    if env_name:
+        return f"http://localhost:{get_api_server_host_port(env_name)}" + api_path
     else:
-        raise RuntimeError(response_json)
+        return "http://localhost:8080" + api_path


-@retry(tries=10, delay=10)
-def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
+@retry(tries=5, delay=5)
+def get_answer_from_query(
+    query: str, only_retrieve_docs: bool, env_name: str
+) -> tuple[list[str], str]:
     filters = IndexFilters(
         source_type=None,
         document_set=None,
@@ -48,42 +34,47 @@ def get_answer_from_query(query: str, run_suffix: str) -> tuple[list[str], str]:
         tags=None,
         access_control_list=None,
     )
-    retrieval_options = RetrievalDetails(
-        run_search=OptionalSearchSetting.ALWAYS,
-        real_time=True,
-        filters=filters,
-        enable_auto_detect_filters=False,
-    )

-    chat_session_id = _create_new_chat_session(run_suffix)
+    messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]

-    url = _api_url_builder(run_suffix, "/chat/send-message-simple-api/")
-
-    new_message_request = BasicCreateChatMessageRequest(
-        chat_session_id=chat_session_id,
-        message=query,
-        retrieval_options=retrieval_options,
-        query_override=query,
+    new_message_request = DirectQARequest(
+        messages=messages,
+        prompt_id=0,
+        persona_id=0,
+        retrieval_options=RetrievalDetails(
+            run_search=OptionalSearchSetting.ALWAYS,
+            real_time=True,
+            filters=filters,
+            enable_auto_detect_filters=False,
+        ),
+        chain_of_thought=False,
+        return_contexts=True,
+        skip_gen_ai_answer_generation=only_retrieve_docs,
     )

+    url = _api_url_builder(env_name, "/query/answer-with-quote/")
+    headers = {
+        "Content-Type": "application/json",
+    }
+
     body = new_message_request.dict()
     body["user"] = None
     try:
-        response_json = requests.post(url, headers=GENERAL_HEADERS, json=body).json()
-        simple_search_docs = response_json.get("simple_search_docs", [])
-        answer = response_json.get("answer", "")
+        response_json = requests.post(url, headers=headers, json=body).json()
+        context_data_list = response_json.get("contexts", {}).get("contexts", [])
+        answer = response_json.get("answer", "") or ""
     except Exception as e:
         print("Failed to answer the questions:")
         print(f"\t {str(e)}")
-        print("trying again")
+        print("Try restarting vespa container and trying again")
         raise e

-    return simple_search_docs, answer
+    return context_data_list, answer


 @retry(tries=10, delay=10)
-def check_if_query_ready(run_suffix: str) -> bool:
-    url = _api_url_builder(run_suffix, "/manage/admin/connector/indexing-status/")
+def check_indexing_status(env_name: str) -> tuple[int, bool]:
+    url = _api_url_builder(env_name, "/manage/admin/connector/indexing-status/")
     try:
         indexing_status_dict = requests.get(url, headers=GENERAL_HEADERS).json()
     except Exception as e:
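
For illustration (not part of the diff): a sketch of how the reworked helper might be called from a test script. The signature and the empty-env_name fallback to localhost:8080 come from the changes above; the sample question is invented, and a running API server is assumed:

```
from tests.regression.answer_quality.api_utils import get_answer_from_query

docs, answer = get_answer_from_query(
    query="What does the connector pause button do?",  # invented sample question
    only_retrieve_docs=True,  # maps to skip_gen_ai_answer_generation above
    env_name="",  # empty env_name targets the default http://localhost:8080
)
print(f"retrieved {len(docs)} context docs; answer: {answer!r}")
```
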
@@ -98,20 +89,21 @@ def check_if_query_ready(run_suffix: str) -> bool:
         status = index_attempt["last_status"]
         if status == IndexingStatus.IN_PROGRESS or status == IndexingStatus.NOT_STARTED:
             ongoing_index_attempts = True
         elif status == IndexingStatus.SUCCESS:
-            doc_count += 16
+            doc_count += index_attempt["docs_indexed"]
+            doc_count -= 16

-    if not doc_count:
-        print("No docs indexed, waiting for indexing to start")
-    elif ongoing_index_attempts:
-        print(
-            f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
-        )
-
-    return doc_count > 0 and not ongoing_index_attempts
+    # all the +16 and -16 are to account for the fact that the indexing status
+    # is only updated every 16 documents and will tell us how many are
+    # chunked, not indexed. probably need to fix this in the future!
+    if doc_count:
+        doc_count += 16
+
+    return doc_count, ongoing_index_attempts


-def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
-    url = _api_url_builder(run_suffix, "/manage/admin/connector/run-once/")
+def run_cc_once(env_name: str, connector_id: int, credential_id: int) -> None:
+    url = _api_url_builder(env_name, "/manage/admin/connector/run-once/")
     body = {
         "connector_id": connector_id,
         "credential_ids": [credential_id],
@@ -126,9 +118,9 @@ def run_cc_once(run_suffix: str, connector_id: int, credential_id: int) -> None:
         print("Failed text:", response.text)


-def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> None:
+def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None:
     url = _api_url_builder(
-        run_suffix, f"/manage/connector/{connector_id}/credential/{credential_id}"
+        env_name, f"/manage/connector/{connector_id}/credential/{credential_id}"
     )

     body = {"name": "zip_folder_contents", "is_public": True}
@@ -141,8 +133,8 @@ def create_cc_pair(run_suffix: str, connector_id: int, credential_id: int) -> No
         print("Failed text:", response.text)


-def _get_existing_connector_names(run_suffix: str) -> list[str]:
-    url = _api_url_builder(run_suffix, "/manage/connector")
+def _get_existing_connector_names(env_name: str) -> list[str]:
+    url = _api_url_builder(env_name, "/manage/connector")

     body = {
         "credential_json": {},
@@ -156,10 +148,10 @@ def _get_existing_connector_names(run_suffix: str) -> list[str]:
         raise RuntimeError(response.__dict__)


-def create_connector(run_suffix: str, file_paths: list[str]) -> int:
-    url = _api_url_builder(run_suffix, "/manage/admin/connector")
+def create_connector(env_name: str, file_paths: list[str]) -> int:
+    url = _api_url_builder(env_name, "/manage/admin/connector")
     connector_name = base_connector_name = "search_eval_connector"
-    existing_connector_names = _get_existing_connector_names(run_suffix)
+    existing_connector_names = _get_existing_connector_names(env_name)

     count = 1
     while connector_name in existing_connector_names:
@@ -174,6 +166,7 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
         refresh_freq=None,
         prune_freq=None,
         disabled=False,
+        indexing_start=None,
     )

     body = connector.dict()
@@ -186,8 +179,8 @@ def create_connector(run_suffix: str, file_paths: list[str]) -> int:
         raise RuntimeError(response.__dict__)


-def create_credential(run_suffix: str) -> int:
-    url = _api_url_builder(run_suffix, "/manage/credential")
+def create_credential(env_name: str) -> int:
+    url = _api_url_builder(env_name, "/manage/credential")
     body = {
         "credential_json": {},
         "admin_public": True,
@@ -201,12 +194,12 @@ def create_credential(run_suffix: str) -> int:


 @retry(tries=10, delay=2, backoff=2)
-def upload_file(run_suffix: str, zip_file_path: str) -> list[str]:
+def upload_file(env_name: str, zip_file_path: str) -> list[str]:
     files = [
         ("files", open(zip_file_path, "rb")),
     ]

-    api_path = _api_url_builder(run_suffix, "/manage/admin/connector/file/upload")
+    api_path = _api_url_builder(env_name, "/manage/admin/connector/file/upload")
     try:
         response = requests.post(api_path, files=files)
         response.raise_for_status()  # Raises an HTTPError for bad responses
@@ -67,20 +67,20 @@ def switch_to_commit(commit_sha: str) -> None:
     print("Repository updated successfully.")


-def get_docker_container_env_vars(suffix: str) -> dict:
+def get_docker_container_env_vars(env_name: str) -> dict:
     """
     Retrieves environment variables from "background" and "api_server" Docker containers.
     """
-    print(f"Getting environment variables for containers with suffix: {suffix}")
+    print(f"Getting environment variables for containers with env_name: {env_name}")

     combined_env_vars = {}
     for container_type in ["background", "api_server"]:
         container_name = _run_command(
-            f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{suffix}/'"
+            f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{env_name}/'"
         )[0].strip()
         if not container_name:
             raise RuntimeError(
-                f"No {container_type} container found with suffix: {suffix}"
+                f"No {container_type} container found with env_name: {env_name}"
             )

         env_vars_json = _run_command(
@@ -95,9 +95,9 @@ def get_docker_container_env_vars(suffix: str) -> dict:
     return combined_env_vars


-def manage_data_directories(suffix: str, base_path: str, use_cloud_gpu: bool) -> None:
+def manage_data_directories(env_name: str, base_path: str, use_cloud_gpu: bool) -> None:
     # Use the user's home directory as the base path
-    target_path = os.path.join(os.path.expanduser(base_path), suffix)
+    target_path = os.path.join(os.path.expanduser(base_path), env_name)
     directories = {
         "DANSWER_POSTGRES_DATA_DIR": os.path.join(target_path, "postgres/"),
         "DANSWER_VESPA_DATA_DIR": os.path.join(target_path, "vespa/"),
@@ -144,12 +144,12 @@ def _is_port_in_use(port: int) -> bool:


 def start_docker_compose(
-    run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
+    env_name: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False
 ) -> None:
     print("Starting Docker Compose...")
     os.chdir(os.path.dirname(__file__))
     os.chdir("../../../../deployment/docker_compose/")
-    command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{run_suffix} up -d"
+    command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack-{env_name} up -d"
     command += " --build"
     command += " --force-recreate"

@@ -175,17 +175,17 @@ def start_docker_compose(
     print("Containers have been launched")


-def cleanup_docker(run_suffix: str) -> None:
+def cleanup_docker(env_name: str) -> None:
     print(
-        f"Deleting Docker containers, volumes, and networks for project suffix: {run_suffix}"
+        f"Deleting Docker containers, volumes, and networks for project env_name: {env_name}"
     )

     stdout, _ = _run_command("docker ps -a --format '{{json .}}'")

     containers = [json.loads(line) for line in stdout.splitlines()]
-    if not run_suffix:
-        run_suffix = datetime.now().strftime("-%Y")
-    project_name = f"danswer-stack{run_suffix}"
+    if not env_name:
+        env_name = datetime.now().strftime("-%Y")
+    project_name = f"danswer-stack{env_name}"
     containers_to_delete = [
         c for c in containers if c["Names"].startswith(project_name)
     ]
@@ -221,23 +221,23 @@ def cleanup_docker(run_suffix: str) -> None:

     networks = stdout.splitlines()

-    networks_to_delete = [n for n in networks if run_suffix in n]
+    networks_to_delete = [n for n in networks if env_name in n]

     if not networks_to_delete:
-        print(f"No networks found containing suffix: {run_suffix}")
+        print(f"No networks found containing env_name: {env_name}")
     else:
         network_names = " ".join(networks_to_delete)
         _run_command(f"docker network rm {network_names}")

         print(
-            f"Successfully deleted {len(networks_to_delete)} networks containing suffix: {run_suffix}"
+            f"Successfully deleted {len(networks_to_delete)} networks containing env_name: {env_name}"
         )


 @retry(tries=5, delay=5, backoff=2)
-def get_api_server_host_port(suffix: str) -> str:
+def get_api_server_host_port(env_name: str) -> str:
     """
-    This pulls all containers with the provided suffix
+    This pulls all containers with the provided env_name
     It then grabs the JSON specific container with a name containing "api_server"
     It then grabs the port info from the JSON and strips out the relevant data
     """
@@ -248,16 +248,16 @@ def get_api_server_host_port(suffix: str) -> str:
     server_jsons = []

     for container in containers:
-        if container_name in container["Names"] and suffix in container["Names"]:
+        if container_name in container["Names"] and env_name in container["Names"]:
             server_jsons.append(container)

     if not server_jsons:
         raise RuntimeError(
-            f"No container found containing: {container_name} and {suffix}"
+            f"No container found containing: {container_name} and {env_name}"
         )
     elif len(server_jsons) > 1:
         raise RuntimeError(
-            f"Too many containers matching {container_name} found, please indicate a suffix"
+            f"Too many containers matching {container_name} found, please indicate an env_name"
         )
     server_json = server_jsons[0]

@@ -278,67 +278,37 @@ def get_api_server_host_port(suffix: str) -> str:
         raise RuntimeError(f"Too many ports matching {client_port} found")
     if not matching_ports:
         raise RuntimeError(
-            f"No port found containing: {client_port} for container: {container_name} and suffix: {suffix}"
+            f"No port found containing: {client_port} for container: {container_name} and env_name: {env_name}"
         )
     return matching_ports[0]


-# Added function to check Vespa container health status
-def is_vespa_container_healthy(suffix: str) -> bool:
-    print(f"Checking health status of Vespa container for suffix: {suffix}")
-
-    # Find the Vespa container
-    stdout, _ = _run_command(
-        f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
-    )
-    container_name = stdout.strip()
-
-    if not container_name:
-        print(f"No Vespa container found with suffix: {suffix}")
-        return False
-
-    # Get the health status
-    stdout, _ = _run_command(
-        f"docker inspect --format='{{{{.State.Health.Status}}}}' {container_name}"
-    )
-    health_status = stdout.strip()
-
-    is_healthy = health_status.lower() == "healthy"
-    print(f"Vespa container '{container_name}' health status: {health_status}")
-
-    return is_healthy
-
-
-# Added function to restart Vespa container
-def restart_vespa_container(suffix: str) -> None:
-    print(f"Restarting Vespa container for suffix: {suffix}")
+def restart_vespa_container(env_name: str) -> None:
+    print(f"Restarting Vespa container for env_name: {env_name}")

     # Find the Vespa container
     stdout, _ = _run_command(
-        f"docker ps -a --format '{{{{.Names}}}}' | awk /vespa/ && /{suffix}/"
+        f"docker ps -a --format '{{{{.Names}}}}' | awk '/index-1/ && /{env_name}/'"
     )
     container_name = stdout.strip()

     if not container_name:
-        raise RuntimeError(f"No Vespa container found with suffix: {suffix}")
+        raise RuntimeError(f"No Vespa container found with env_name: {env_name}")

     # Restart the container
     _run_command(f"docker restart {container_name}")

     print(f"Vespa container '{container_name}' has begun restarting")

-    time_to_wait = 5
-    while not is_vespa_container_healthy(suffix):
-        print(f"Waiting {time_to_wait} seconds for vespa container to restart")
-        time.sleep(5)
+    time.sleep(30)
     print(f"Vespa container '{container_name}' has been restarted")


 if __name__ == "__main__":
     """
-    Running this just cleans up the docker environment for the container indicated by existing_test_suffix
-    If no existing_test_suffix is indicated, will just clean up all danswer docker containers/volumes/networks
+    Running this just cleans up the docker environment for the container indicated by environment_name
+    If no environment_name is indicated, will just clean up all danswer docker containers/volumes/networks
     Note: vespa/postgres mounts are not deleted
     """
     current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -348,4 +318,4 @@ if __name__ == "__main__":

     if not isinstance(config, dict):
         raise TypeError("config must be a dictionary")
-    cleanup_docker(config["existing_test_suffix"])
+    cleanup_docker(config["environment_name"])
@@ -1,8 +1,13 @@
 import os
+import tempfile
+import time
+import zipfile
+from pathlib import Path
 from types import SimpleNamespace

 import yaml

+from tests.regression.answer_quality.api_utils import check_indexing_status
 from tests.regression.answer_quality.api_utils import create_cc_pair
 from tests.regression.answer_quality.api_utils import create_connector
 from tests.regression.answer_quality.api_utils import create_credential
@@ -10,15 +15,65 @@ from tests.regression.answer_quality.api_utils import run_cc_once
 from tests.regression.answer_quality.api_utils import upload_file


-def upload_test_files(zip_file_path: str, run_suffix: str) -> None:
+def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
+    persistent_dir = tempfile.mkdtemp()
+    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
+        zip_ref.extractall(persistent_dir)
+
+    return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]
+
+
+def create_temp_zip_from_files(file_paths: list[str]) -> str:
+    persistent_dir = tempfile.mkdtemp()
+    zip_file_path = os.path.join(persistent_dir, "temp.zip")
+
+    with zipfile.ZipFile(zip_file_path, "w") as zip_file:
+        for file_path in file_paths:
+            zip_file.write(file_path, Path(file_path).name)
+
+    return zip_file_path
+
+
+def upload_test_files(zip_file_path: str, env_name: str) -> None:
     print("zip:", zip_file_path)
-    file_paths = upload_file(run_suffix, zip_file_path)
+    file_paths = upload_file(env_name, zip_file_path)

-    conn_id = create_connector(run_suffix, file_paths)
-    cred_id = create_credential(run_suffix)
+    conn_id = create_connector(env_name, file_paths)
+    cred_id = create_credential(env_name)

-    create_cc_pair(run_suffix, conn_id, cred_id)
-    run_cc_once(run_suffix, conn_id, cred_id)
+    create_cc_pair(env_name, conn_id, cred_id)
+    run_cc_once(env_name, conn_id, cred_id)
+
+
+def manage_file_upload(zip_file_path: str, env_name: str) -> None:
+    unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
+    total_file_count = len(unzipped_file_paths)
+
+    while True:
+        doc_count, ongoing_index_attempts = check_indexing_status(env_name)
+
+        if ongoing_index_attempts:
+            print(
+                f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..."
+            )
+        elif not doc_count:
+            print("No docs indexed, waiting for indexing to start")
+            upload_test_files(zip_file_path, env_name)
+        elif doc_count < total_file_count:
+            print(f"No ongoing indexing attempts but only {doc_count} docs indexed")
+            remaining_files = unzipped_file_paths[doc_count:]
+            print(f"Grabbed last {len(remaining_files)} docs to try again")
+            temp_zip_file_path = create_temp_zip_from_files(remaining_files)
+            upload_test_files(temp_zip_file_path, env_name)
+            os.unlink(temp_zip_file_path)
+        else:
+            print(f"Successfully uploaded {doc_count} docs!")
+            break
+
+        time.sleep(10)
+
+    for file in unzipped_file_paths:
+        os.unlink(file)


 if __name__ == "__main__":
@@ -27,5 +82,5 @@ if __name__ == "__main__":
     with open(config_path, "r") as file:
         config = SimpleNamespace(**yaml.safe_load(file))
     file_location = config.zipped_documents_file
-    run_suffix = config.existing_test_suffix
-    upload_test_files(file_location, run_suffix)
+    env_name = config.environment_name
+    manage_file_upload(file_location, env_name)
Some files were not shown because too many files have changed in this diff.