Compare commits

..

50 Commits

Author SHA1 Message Date
pablonyx
9087320a06 fix 2025-03-06 14:46:20 -08:00
pablonyx
b0af1458c0 ensure checks pass 2025-03-06 14:46:20 -08:00
pablonyx
bb67a7a122 remove unnecessary logs 2025-03-06 14:46:20 -08:00
pablonyx
e239dc31c1 rename 2025-03-06 14:46:19 -08:00
pablonyx
027128502c add csl 2025-03-06 14:46:19 -08:00
Chris Weaver
a7a374dc81 Confluence fixes (#4220)
* Confluence fixes

* Small tweak

* Address greptile comments
2025-03-06 20:57:07 +00:00
rkuo-danswer
facc8cc2fa add scope needed for permission sync (#4198)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 20:03:38 +00:00
rkuo-danswer
2c0af0a0ca Feature/helm updates (#4201)
* add ingress for api and web

* helm setup docs

* add letsencrypt. close blocks

* use pathType ImplementationSpecific as Prefix is deprecated

* fix backend labels. configure nginx routes. update annotations

* fix linting

---------

Co-authored-by: Sajjad Anwar <sajjadkm@gmail.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 19:48:20 +00:00
pablonyx
bfbc1cd954 k (#4172) 2025-03-06 18:55:12 +00:00
pablonyx
626da583aa Fix gated tenants (#4177)
* fix

* mypy .
2025-03-06 18:07:15 +00:00
pablonyx
92faca139d Fix extra tenant mystery (#4197)
* fix extra tenant mystery

* nit
2025-03-06 18:06:49 +00:00
pablonyx
cec05c5ee9 Revert "k"
This reverts commit 687122911d.
2025-03-06 09:38:31 -08:00
Richard Kuo (Danswer)
eaf054ef06 oauth router went missing? 2025-03-05 15:50:23 -08:00
pablonyx
a7a1a24658 minor nit 2025-03-05 15:35:02 -08:00
pablonyx
687122911d k 2025-03-05 15:27:14 -08:00
pablonyx
40953bd4fe Workspace configs (#4202) 2025-03-05 12:28:44 -08:00
rkuo-danswer
a7acc07e79 fix usage report pagination (#4183)
* early work in progress

* rename utility script

* move actual data seeding to a shareable function

* add test

* make the test pass with the fix

* fix comment

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-05 19:13:51 +00:00
pablonyx
b6e9e65bb8 * Replaces Amazon and Anthropic Icons with version better suitable fo… (#4190)
* * Replaces Amazon and Anthropic Icons with version better suitable for both Dark and  Light modes;
* Adds icon for DeepSeek;
* Simplify logic on icon selection;
* Adds entries for Phi-4, Claude 3.7, Ministral and Gemini 2.0 models

* nit

* k

* k

---------

Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
2025-03-05 17:57:39 +00:00
pablonyx
20f2b9b2bb Add image support for search (#4090)
* add support for image search

* quick fix up

* k

* k

* k

* k

* nit

* quick fix for connector tests
2025-03-05 17:44:18 +00:00
Chris Weaver
f731beca1f Add ONYX_QUERY_HISTORY_TYPE to the dev compose files (#4196) 2025-03-05 17:34:55 +00:00
Weves
fe246aecbb Attempt to address tool happy claude 2025-03-05 09:47:27 -08:00
pablonyx
50ad066712 Better filtering (#4185)
* k

* k

* k

* k

* k
2025-03-05 04:35:50 +00:00
rkuo-danswer
870b59a1cc Bugfix/vertex crash (#4181)
* Update text embedding model to version 005 and enhance embedding retrieval process

* re

* Fix formatting issues

* Add support for Bedrock reranking provider and AWS credentials handling

* fix: improve AWS key format validation and error messages

* Fix vertex embedding model crash

* feat: add environment template for local development setup

* Add display name for Claude 3.7 Sonnet model

* Add display names for Gemini 2.0 models and update Claude 3.7 Sonnet entry

* Fix ruff errors by ensuring lines are within 130 characters

* revert to currently default onyx browser settings

* add / fix boto requirements

---------

Co-authored-by: ferdinand loesch <f.loesch@sportradar.com>
Co-authored-by: Ferdinand Loesch <ferdinandloesch@me.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-05 01:59:46 +00:00
pablonyx
5c896cb0f7 add minor fixes (#4170) 2025-03-04 20:29:28 +00:00
pablonyx
184b30643d Nit: logging adjustments (#4182) 2025-03-04 11:39:53 -08:00
pablonyx
ae585fd84c Delete all chats (#4171)
* nit

* k
2025-03-04 10:00:08 -08:00
rkuo-danswer
61e8f371b9 fix blowing up the entire task on exception and trying to reuse an in… (#4179)
* fix blowing up the entire task on exception and trying to reuse an invalid db session

* list comprehension

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-04 00:57:27 +00:00
rkuo-danswer
33cc4be492 Bugfix/GitHub validation (#4173)
* fixing unexpected errors disabling connectors

* rename UnexpectedError to UnexpectedValidationError

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-04 00:09:49 +00:00
joachim-danswer
117c8c0d78 Enable ephemeral message responses by Onyx Slack Bots (#4142)
A new setting 'is_ephemeral' has been added to the Slack channel configurations. 

Key features/effects:

  - if is_ephemeral is set for standard channel (and a Search Assistant is chosen):
     - the answer is only shown to user as an ephemeral message
     - the user has access to his private documents for a search (as the answer is only shown to them) 
     - the user has the ability to share the answer with the channel or keep private
     - a recipient list cannot be defined if the channel is set up as ephemeral
 
  - if is_ephemeral is set and DM with bot:
    - the user has access to private docs in searches
    - the message is not sent as ephemeral, as it is a 1:1 discussion with bot

 - if is_ephemeral is not set but recipient list is set:
    - the user search does *not* have access to their private documents as the information goes to the recipient list team members, and they may have different access rights

 - Overall:
     - Unless the channel is set to is_ephemeral or it is a direct conversation with the Bot, only public docs are accessible  
     - The ACL is never bypassed, also not in cases where the admin explicitly attached a document set to the bot config.
2025-03-03 15:02:21 -08:00
rkuo-danswer
9bb8cdfff1 fix web connector tests to handle new deduping (#4175)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-03 20:54:20 +00:00
Weves
a52d0d29be Small tweak to NumberInput 2025-03-03 11:20:53 -08:00
Chris Weaver
f25e1e80f6 Add option to not re-index (#4157)
* Add option to not re-index

* Add quantizaton / dimensionality override support

* Fix build / ut
2025-03-03 10:54:11 -08:00
Yuhong Sun
39fd6919ad Fix web scrolling 2025-03-03 09:00:05 -08:00
Yuhong Sun
7f0653d173 Handling of #! sites (#4169) 2025-03-03 08:18:44 -08:00
SubashMohan
e9905a398b Enhance iframe content extraction and add thresholds for JavaScript disabled scenarios (#4167) 2025-03-02 19:29:10 -08:00
Brad Slavin
3ed44e8bae Update Unstructured documentation URL to new location (#4168) 2025-03-02 19:16:38 -08:00
pablonyx
64158a5bdf silence_logs (#4165) 2025-03-02 19:00:59 +00:00
pablonyx
afb2393596 fix dark mode index attempt failure (#4163) 2025-03-02 01:23:16 +00:00
pablonyx
d473c4e876 Fix curator default persona editing (#4158)
* k

* k
2025-03-02 00:40:14 +00:00
pablonyx
692058092f fix typo 2025-03-01 13:00:07 -08:00
pablonyx
e88325aad6 bump version (#4164) 2025-03-01 01:58:45 +00:00
pablonyx
7490250e91 Fix user group edge case (#4159)
* fix user group

* k
2025-02-28 23:55:21 +00:00
pablonyx
e5369fcef8 Update warning copy (#4160)
* k

* k

* quick nit
2025-02-28 23:46:21 +00:00
Yuhong Sun
b0f00953bc Add CODEOWNERS 2025-02-28 13:57:33 -08:00
rkuo-danswer
f6a75c86c6 Bugfix/emit background error (#4156)
* print the test name when it runs

* type hints

* can't reuse session after an exception

* better logging

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-28 18:35:24 +00:00
pablonyx
ed9989282f nit- update casing enforcement on frontend 2025-02-28 10:09:06 -08:00
pablonyx
e80a0f2716 Improved google connector flow (#4155)
* fix handling

* k

* k

* fix function

* k

* k
2025-02-28 05:13:39 +00:00
rkuo-danswer
909403a648 Feature/confluence oauth (#3477)
* first cut at slack oauth flow

* fix usage of hooks

* fix button spacing

* add additional error logging

* no dev redirect

* early cut at google drive oauth

* second pass

* switch to production uri's

* try handling oauth_interactive differently

* pass through client id and secret if uploaded

* fix call

* fix test

* temporarily disable check for testing

* Revert "temporarily disable check for testing"

This reverts commit 4b5a022a5f.

* support visibility in test

* missed file

* first cut at confluence oauth

* work in progress

* work in progress

* work in progress

* work in progress

* work in progress

* first cut at distributed locking

* WIP to make test work

* add some dev mode affordances and gate usage of redis behind dynamic credentials

* mypy and credentials provider fixes

* WIP

* fix created at

* fix setting initialValue on everything

* remove debugging, fix ??? some TextFormField issues

* npm fixes

* comment cleanup

* fix comments

* pin the size of the card section

* more review fixes

* more fixes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-28 03:48:51 +00:00
pablonyx
cd84b65011 quick fix (#4154) 2025-02-28 02:03:34 +00:00
pablonyx
413f21cec0 Filter assistants fix (#4153)
* k

* quick nit

* minor assistant filtering fix
2025-02-28 02:03:21 +00:00
219 changed files with 9615 additions and 3267 deletions

1
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1 @@
* @onyx-dot-app/onyx-core-team

View File

@@ -0,0 +1,125 @@
"""Update GitHub connector repo_name to repositories
Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962
"""
from alembic import op
import sqlalchemy as sa
import json
import logging
# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
# First get all GitHub connectors
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
# Update each connector's config
updated_count = 0
for connector_id, config in github_connectors:
try:
if not config:
logger.warning(f"Connector {connector_id} has no config, skipping")
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repo_name" not in config:
continue
# Create new config with repositories instead of repo_name
new_config = dict(config)
repo_name_value = new_config.pop("repo_name")
new_config["repositories"] = repo_name_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
updated_count += 1
except Exception as e:
logger.error(f"Error updating connector {connector_id}: {str(e)}")
def downgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
logger.debug(
"Starting rollback of GitHub connectors from repositories to repo_name"
)
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
# Revert each GitHub connector to use repo_name instead of repositories
reverted_count = 0
for connector_id, config in github_connectors:
try:
if not config:
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repositories" not in config:
continue
# Create new config with repo_name instead of repositories
new_config = dict(config)
repositories_value = new_config.pop("repositories")
new_config["repo_name"] = repositories_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"new_config": json.dumps(new_config), "connector_id": connector_id},
)
reverted_count += 1
except Exception as e:
logger.error(f"Error reverting connector {connector_id}: {str(e)}")

View File

@@ -0,0 +1,55 @@
"""add background_reindex_enabled field
Revision ID: b7c2b63c4a03
Revises: f11b408e39d3
Create Date: 2024-03-26 12:34:56.789012
"""
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import EmbeddingPrecision
# revision identifiers, used by Alembic.
revision = "b7c2b63c4a03"
down_revision = "f11b408e39d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add background_reindex_enabled column with default value of True
op.add_column(
"search_settings",
sa.Column(
"background_reindex_enabled",
sa.Boolean(),
nullable=False,
server_default="true",
),
)
# Add embedding_precision column with default value of FLOAT
op.add_column(
"search_settings",
sa.Column(
"embedding_precision",
sa.Enum(EmbeddingPrecision, native_enum=False),
nullable=False,
server_default=EmbeddingPrecision.FLOAT.name,
),
)
# Add reduced_dimension column with default value of None
op.add_column(
"search_settings",
sa.Column("reduced_dimension", sa.Integer(), nullable=True),
)
def downgrade() -> None:
# Remove the background_reindex_enabled column
op.drop_column("search_settings", "background_reindex_enabled")
op.drop_column("search_settings", "embedding_precision")
op.drop_column("search_settings", "reduced_dimension")

View File

@@ -4,7 +4,8 @@ from ee.onyx.server.reporting.usage_export_generation import create_new_usage_re
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -18,7 +19,26 @@ logger = setup_logger()
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
try:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
#####

View File

@@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""

View File

@@ -134,7 +134,9 @@ def fetch_chat_sessions_eagerly_by_time(
limit: int | None = 500,
initial_time: datetime | None = None,
) -> list[ChatSession]:
time_order: UnaryExpression = desc(ChatSession.time_created)
"""Sorted by oldest to newest, then by message id"""
asc_time_order: UnaryExpression = asc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)
filters: list[ColumnElement | BinaryExpression] = [
@@ -147,8 +149,7 @@ def fetch_chat_sessions_eagerly_by_time(
subquery = (
db_session.query(ChatSession.id, ChatSession.time_created)
.filter(*filters)
.order_by(ChatSession.id, time_order)
.distinct(ChatSession.id)
.order_by(asc_time_order)
.limit(limit)
.subquery()
)
@@ -164,7 +165,7 @@ def fetch_chat_sessions_eagerly_by_time(
ChatMessage.chat_message_feedbacks
),
)
.order_by(time_order, message_order)
.order_by(asc_time_order, message_order)
)
chat_sessions = query.all()

View File

@@ -16,13 +16,18 @@ from onyx.db.models import UsageReport
from onyx.file_store.file_store import get_default_file_store
# Gets skeletons of all message
# Gets skeletons of all messages in the given range
def get_empty_chat_messages_entries__paginated(
db_session: Session,
period: tuple[datetime, datetime],
limit: int | None = 500,
initial_time: datetime | None = None,
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
"""Returns a tuple where:
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],
end=period[1],
@@ -52,18 +57,17 @@ def get_empty_chat_messages_entries__paginated(
if len(chat_sessions) == 0:
return None, []
return chat_sessions[0].time_created, message_skeletons
return chat_sessions[-1].time_created, message_skeletons
def get_all_empty_chat_message_entries(
db_session: Session,
period: tuple[datetime, datetime],
) -> Generator[list[ChatMessageSkeleton], None, None]:
"""period is the range of time over which to fetch messages."""
initial_time: Optional[datetime] = period[0]
ind = 0
while True:
ind += 1
# iterate from oldest to newest
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
db_session,
period,

View File

@@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
)
# if the user is a curator in any of their groups, set their role to CURATOR
# otherwise, set their role to BASIC
# otherwise, set their role to BASIC only if they were previously a CURATOR
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
@@ -631,7 +631,16 @@ def update_user_group(
removed_users = db_session.scalars(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
_validate_curator_status__no_commit(db_session, list(removed_users))
# Filter out admin and global curator users before validating curator status
users_to_validate = [
user
for user in removed_users
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
]
if users_to_validate:
_validate_curator_status__no_commit(db_session, users_to_validate)
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()

View File

@@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -342,7 +346,8 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -354,7 +359,11 @@ def confluence_doc_sync(
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), "confluence", cc_pair.credential_id
)
confluence_connector.set_credentials_provider(provider)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)

View File

@@ -1,9 +1,11 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -61,13 +63,27 @@ def _build_group_member_email_map(
def confluence_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
confluence_client = build_confluence_client(
credentials=cc_pair.credential.credential_json,
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
)
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
url = wiki_base.rstrip("/")
probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
confluence_client = OnyxConfluence(is_cloud, url, provider)
confluence_client._probe_connection(**probe_kwargs)
confluence_client._initialize_connection(**final_kwargs)
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,

View File

@@ -32,7 +32,8 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -119,6 +119,7 @@ def _build_onyx_groups(
def gdrive_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects

View File

@@ -123,7 +123,8 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -28,6 +28,7 @@ DocSyncFuncType = Callable[
GroupSyncFuncType = Callable[
[
str,
ConnectorCredentialPair,
],
list[ExternalUserGroup],

View File

@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth import router as oauth_router
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -128,7 +128,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, oauth_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(
@@ -152,4 +152,8 @@ def get_application() -> FastAPI:
# environment variable. Used to automate deployment for multiple environments.
seed_db()
# for debugging discovered routes
# for route in application.router.routes:
# print(f"Path: {route.path}, Methods: {route.methods}")
return application

View File

@@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.utils.logger import OnyxLoggingAdapter
from onyx.utils.logger import setup_logger
@@ -216,7 +216,7 @@ def _handle_standard_answers(
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
@@ -231,6 +231,7 @@ def _handle_standard_answers(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
receiver_ids=receiver_ids,
)
return True

View File

@@ -1,629 +0,0 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/oauth")
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.onyx.app/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class ConfluenceCloudOAuth:
"""work in progress"""
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
tenant_id = get_current_tenant_id()
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
# email=user.email, redirect_on_success=redirect_on_success
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": SlackOAuth.REDIRECT_URI,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"team_id": team_id,
"authed_user_id": authed_user_id,
"redirect_on_success": session.redirect_on_success,
}
)
# Work in progress
# @router.post("/connector/confluence/callback")
# def handle_confluence_oauth_callback(
# code: str,
# state: str,
# user: User = Depends(current_user),
# db_session: Session = Depends(get_session),
# tenant_id: str | None = Depends(get_current_tenant_id),
# ) -> JSONResponse:
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
# raise HTTPException(
# status_code=500,
# detail="Confluence client ID or client secret is not configured."
# )
# r = get_redis_client(tenant_id=tenant_id)
# # recover the state
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
# # Convert bytes back to a UUID
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
# oauth_uuid_str = str(oauth_uuid)
# r_key = f"da_oauth:{oauth_uuid_str}"
# result = r.get(r_key)
# if not result:
# raise HTTPException(
# status_code=400,
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
# )
# try:
# session = ConfluenceCloudOAuth.parse_session(result)
# # Exchange the authorization code for an access token
# response = requests.post(
# ConfluenceCloudOAuth.TOKEN_URL,
# headers={"Content-Type": "application/x-www-form-urlencoded"},
# data={
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
# "code": code,
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
# },
# )
# response_data = response.json()
# if not response_data.get("ok"):
# raise HTTPException(
# status_code=400,
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
# )
# # Extract token and team information
# access_token: str = response_data.get("access_token")
# team_id: str = response_data.get("team", {}).get("id")
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
# credential_info = CredentialBase(
# credential_json={"slack_bot_token": access_token},
# admin_public=True,
# source=DocumentSource.CONFLUENCE,
# name="Confluence OAuth",
# )
# logger.info(f"Slack access token: {access_token}")
# credential = create_credential(credential_info, user, db_session)
# logger.info(f"new_credential_id={credential.id}")
# except Exception as e:
# return JSONResponse(
# status_code=500,
# content={
# "success": False,
# "message": f"An error occurred during Slack OAuth: {str(e)}",
# },
# )
# finally:
# r.delete(r_key)
# # return the result
# return JSONResponse(
# content={
# "success": True,
# "message": "Slack OAuth completed successfully.",
# "team_id": team_id,
# "authed_user_id": authed_user_id,
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)

View File

@@ -0,0 +1,91 @@
import base64
import uuid
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from ee.onyx.server.oauth.api_router import router
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str | None = None
if connector == DocumentSource.SLACK:
if not DEV_MODE:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.CONFLUENCE:
if not DEV_MODE:
oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
session = ConfluenceCloudOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
if not DEV_MODE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
if not session:
raise HTTPException(
status_code=500,
detail=f"The document source type {connector} failed to generate an OAuth session.",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})

View File

@@ -0,0 +1,3 @@
from fastapi import APIRouter
router: APIRouter = APIRouter(prefix="/oauth")

View File

@@ -0,0 +1,362 @@
import base64
import uuid
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ConfluenceCloudOAuth:
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
class AccessibleResources(BaseModel):
id: str
name: str
url: str
scopes: list[str]
avatarUrl: str
CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL
ACCESSIBLE_RESOURCE_URL = (
"https://api.atlassian.com/oauth/token/accessible-resources"
)
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
# classic scope
"read:confluence-space.summary%20"
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence%20"
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"read:content-details:confluence%20" # for permission sync
"offline_access"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&redirect_uri={redirect_uri}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
return session
@classmethod
def generate_finalize_url(cls, credential_id: int) -> str:
return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}"
@router.post("/connector/confluence/callback")
def confluence_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Handles the backend logic for the frontend page that the user is redirected to
after visiting the oauth authorization url."""
if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Confluence Cloud client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = ConfluenceCloudOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
else:
redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
ConfluenceCloudOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": ConfluenceCloudOAuth.CLIENT_ID,
"client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
token_response: ConfluenceCloudOAuth.TokenResponse | None = None
try:
token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
response.text
)
except Exception:
raise RuntimeError(
"Confluence Cloud OAuth failed during code/token exchange."
)
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
credential_info = CredentialBase(
credential_json={
"confluence_access_token": token_response.access_token,
"confluence_refresh_token": token_response.refresh_token,
"created_at": now.isoformat(),
"expires_at": expires_at.isoformat(),
"expires_in": token_response.expires_in,
"scope": token_response.scope,
},
admin_public=True,
source=DocumentSource.CONFLUENCE,
name="Confluence Cloud OAuth",
)
credential = create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth completed successfully.",
"finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
"redirect_on_success": session.redirect_on_success,
}
)
@router.get("/connector/confluence/accessible-resources")
def confluence_oauth_accessible_resources(
credential_id: int,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Atlassian's API is weird and does not supply us with enough info to be in a
usable state after authorizing. All API's require a cloud id. We have to list
the accessible resources/sites and let the user choose which site to use."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(400, f"Credential {credential_id} not found.")
credential_dict = credential.credential_json
access_token = credential_dict["confluence_access_token"]
try:
# Exchange the authorization code for an access token
response = requests.get(
ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
},
)
response.raise_for_status()
accessible_resources_data = response.json()
# Validate the list of AccessibleResources
try:
accessible_resources = [
ConfluenceCloudOAuth.AccessibleResources(**resource)
for resource in accessible_resources_data
]
except ValidationError as e:
raise RuntimeError(f"Failed to parse accessible resources: {e}")
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud get accessible resources completed successfully.",
"accessible_resources": [
resource.model_dump() for resource in accessible_resources
],
}
)
@router.post("/connector/confluence/finalize")
def confluence_oauth_finalize(
credential_id: int,
cloud_id: str,
cloud_name: str,
cloud_url: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Saves the info for the selected cloud site to the credential.
This is the final step in the confluence oauth flow where after the traditional
OAuth process, the user has to select a site to associate with the credentials.
After this, the credential is usable."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
)
new_credential_json: dict[str, Any] = dict(credential.credential_json)
new_credential_json["cloud_id"] = cloud_id
new_credential_json["cloud_name"] = cloud_name
new_credential_json["wiki_base"] = cloud_url
try:
update_credential_json(credential_id, new_credential_json, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth finalized successfully.",
"redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
}
)

View File

@@ -0,0 +1,229 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = GoogleDriveOAuth.REDIRECT_URI
else:
redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
}
)

View File

@@ -0,0 +1,197 @@
import base64
import uuid
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.danswer.dev/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = SlackOAuth.REDIRECT_URI
else:
redirect_uri = SlackOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
"team_id": team_id,
"authed_user_id": authed_user_id,
}
)

View File

@@ -48,4 +48,5 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}

View File

@@ -55,7 +55,11 @@ logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
"""
Get existing tenant ID for an email or create a new tenant if none exists.
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
@@ -104,14 +108,14 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
status_code=409, detail="User already belongs to an organization"
)
logger.info(f"Provisioning tenant: {tenant_id}")
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
if not create_schema_if_not_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
logger.debug(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

View File

@@ -6,7 +6,7 @@ MODEL_WARM_UP_STRING = "hi " * 512
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-004"
DEFAULT_VERTEX_MODEL = "text-embedding-005"
class EmbeddingModelTextType:

View File

@@ -5,6 +5,7 @@ from types import TracebackType
from typing import cast
from typing import Optional
import aioboto3 # type: ignore
import httpx
import openai
import vertexai # type: ignore
@@ -28,11 +29,13 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.utils import pass_aws_key
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
@@ -78,7 +81,7 @@ class CloudEmbedding:
self._closed = False
async def _embed_openai(
self, texts: list[str], model: str | None
self, texts: list[str], model: str | None, reduced_dimension: int | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
@@ -91,7 +94,11 @@ class CloudEmbedding:
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(input=text_batch, model=model)
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
@@ -178,17 +185,24 @@ class CloudEmbedding:
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
embeddings = await client.get_embeddings_async(
[
TextEmbeddingInput(
text,
embedding_type,
)
for text in texts
],
auto_truncate=True, # This is the default
)
return [embedding.values for embedding in embeddings]
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
# Split into batches of 25 texts
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
batches = [
inputs[i : i + max_texts_per_batch]
for i in range(0, len(inputs), max_texts_per_batch)
]
# Dispatch all embedding calls asynchronously at once
tasks = [
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
]
# Wait for all tasks to complete in parallel
results = await asyncio.gather(*tasks)
return [embedding.values for batch in results for embedding in batch]
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
@@ -223,9 +237,10 @@ class CloudEmbedding:
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name)
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
@@ -326,6 +341,7 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
reduced_dimension: int | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
@@ -369,6 +385,7 @@ async def embed_text(
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
reduced_dimension=reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
@@ -440,7 +457,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
)
async def cohere_rerank(
async def cohere_rerank_api(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
@@ -450,6 +467,45 @@ async def cohere_rerank(
return [result.relevance_score for result in sorted_results]
async def cohere_rerank_aws(
query: str,
docs: list[str],
model_name: str,
region_name: str,
aws_access_key_id: str,
aws_secret_access_key: str,
) -> list[float]:
session = aioboto3.Session(
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
)
async with session.client(
"bedrock-runtime", region_name=region_name
) as bedrock_client:
body = json.dumps(
{
"query": query,
"documents": docs,
"api_version": 2,
}
)
# Invoke the Bedrock model asynchronously
response = await bedrock_client.invoke_model(
modelId=model_name,
accept="application/json",
contentType="application/json",
body=body,
)
# Read the response asynchronously
response_body = json.loads(await response["body"].read())
# Extract and sort the results
results = response_body.get("results", [])
sorted_results = sorted(results, key=lambda item: item["index"])
return [result["relevance_score"] for result in sorted_results]
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
@@ -508,6 +564,7 @@ async def process_embed_request(
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
reduced_dimension=embed_request.reduced_dimension,
prefix=prefix,
gpu_type=gpu_type,
)
@@ -564,15 +621,32 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = await cohere_rerank(
sim_scores = await cohere_rerank_api(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
if rerank_request.api_key is None:
raise RuntimeError("Bedrock Rerank Requires an API Key")
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
rerank_request.api_key
)
sim_scores = await cohere_rerank_aws(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
region_name=aws_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")
raise HTTPException(

View File

@@ -70,3 +70,32 @@ def get_gpu_type() -> str:
return GPUStatus.MAC_MPS
return GPUStatus.NONE
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
"""Parse AWS API key string into components.
Args:
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
Returns:
Tuple of (access_key, secret_key, region)
Raises:
ValueError: If key format is invalid
"""
if not api_key.startswith("aws"):
raise ValueError("API key must start with 'aws' prefix")
parts = api_key.split("_")
if len(parts) != 4:
raise ValueError(
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
"this is an onyx specific format for formatting the aws secrets for bedrock"
)
try:
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
return aws_access_key_id, aws_secret_access_key, aws_region
except Exception as e:
raise ValueError(f"Failed to parse AWS key components: {str(e)}")

View File

@@ -98,8 +98,16 @@ def choose_tool(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=built_prompt,
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
tools=(
[tool.tool_definition() for tool in tools] or None
if using_tool_calling_llm
else None
),
tool_choice=(
"required"
if tools and force_use_tool.force_use and using_tool_calling_llm
else None
),
structured_response_format=structured_response_format,
)

View File

@@ -523,6 +523,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
user_count = await get_user_count()
logger.debug(f"Current tenant user count: {user_count}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if user_count == 1:
@@ -544,7 +545,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.notice(f"User {user.id} has registered.")
logger.debug(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
data={"action": "create"},
@@ -586,14 +587,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> Optional[User]:
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
)
tenant_id: str | None = None
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
None,
)(
email=email,
)
except Exception as e:
logger.warning(
f"User attempted to login with invalid credentials: {str(e)}"
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)

View File

@@ -423,7 +423,7 @@ def connector_external_group_sync_generator_task(
)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(cc_pair)
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(

View File

@@ -23,9 +23,9 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
@@ -61,7 +61,7 @@ from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_index_swap
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
@@ -406,7 +406,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# check for search settings swap
with get_session_with_current_tenant() as db_session:
old_search_settings = check_index_swap(db_session=db_session)
old_search_settings = check_and_perform_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
@@ -439,6 +439,15 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
with get_session_with_current_tenant() as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
# skip non-live search settings that don't have background reindex enabled
# those should just auto-change to live shortly after creation without
# requiring any indexing till that point
if (
not search_settings_instance.status.is_current()
and not search_settings_instance.background_reindex_enabled
):
continue
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -456,23 +465,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
if not should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if search_settings_instance.status.is_current():
# the indexing trigger is only checked and cleared with the current search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True

View File

@@ -346,11 +346,10 @@ def validate_indexing_fences(
return
def _should_index(
def should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -415,9 +414,9 @@ def _should_index(
):
return False
if search_settings_primary:
if search_settings_instance.status.is_current():
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
# if a manual indexing trigger is on the cc pair, honor it for live search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq

View File

@@ -11,10 +11,27 @@ def emit_background_error(
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_current_tenant() as db_session:
try:
error_message = ""
# try to write to the db, but handle IntegrityError specifically
try:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = (
f"Failed to create background error: {str(e)}. Original message: {message}"
)
except Exception:
pass
if not error_message:
return
# if we get here from an IntegrityError, try to write the error message to the db
# we need a new session because the first session is now invalid
try:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, error_message, None)
except Exception:
pass

View File

@@ -22,6 +22,7 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -92,11 +93,17 @@ def _get_connector_runner(
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()
except UnexpectedValidationError as e:
logger.exception(
"Unable to instantiate connector due to an unexpected temporary issue."
)
raise e
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
logger.exception("Unable to instantiate connector. Pausing until fixed.")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed. Sometimes there are cases where the connector will
# it will never succeed
# Sometimes there are cases where the connector will
# intermittently fail to initialize in which case we should pass in
# leave_connector_active=True to allow it to continue.
# For example, if there is nightly maintenance on a Confluence Server instance,

View File

@@ -756,6 +756,7 @@ def stream_chat_message_objects(
)
# LLM prompt building, response capturing, etc.
answer = Answer(
prompt_builder=prompt_builder,
is_connected=is_connected,

View File

@@ -640,3 +640,6 @@ TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
MOCK_LLM_RESPONSE = (
os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
)
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20

View File

@@ -0,0 +1,38 @@
from onyx.configs.app_configs import DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB
from onyx.server.settings.store import load_settings
def get_image_extraction_and_analysis_enabled() -> bool:
"""Get image extraction and analysis enabled setting from workspace settings or fallback to False"""
try:
settings = load_settings()
if settings.image_extraction_and_analysis_enabled is not None:
return settings.image_extraction_and_analysis_enabled
except Exception:
pass
return False
def get_search_time_image_analysis_enabled() -> bool:
"""Get search time image analysis enabled setting from workspace settings or fallback to False"""
try:
settings = load_settings()
if settings.search_time_image_analysis_enabled is not None:
return settings.search_time_image_analysis_enabled
except Exception:
pass
return False
def get_image_analysis_max_size_mb() -> int:
"""Get image analysis max size MB setting from workspace settings or fallback to environment variable"""
try:
settings = load_settings()
if settings.image_analysis_max_size_mb is not None:
return settings.image_analysis_max_size_mb
except Exception:
pass
return DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB

View File

@@ -200,7 +200,6 @@ class AirtableConnector(LoadConnector):
return attachment_response.content
logger.error(f"Failed to refresh attachment for {filename}")
raise
attachment_content = get_attachment_with_retry(url, record_id)

View File

@@ -18,7 +18,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -310,7 +310,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# Catch-all for anything not captured by the above
# Since we are unsure of the error and it may not disable the connector,
# raise an unexpected error (does not disable connector)
raise UnexpectedError(
raise UnexpectedValidationError(
f"Unexpected error during blob storage settings validation: {e}"
)

View File

@@ -11,17 +11,19 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import extract_text_from_confluence_html
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import attachment_to_content
from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import convert_attachment_to_content
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
from onyx.connectors.confluence.utils import process_attachment
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -33,28 +35,26 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Potential Improvements
# 1. Include attachments, etc
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
# 1. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]
_PAGE_EXPANSION_FIELDS = [
"body.storage.value",
"version",
"space",
"metadata.labels",
"history.lastUpdated",
]
_ATTACHMENT_EXPANSION_FIELDS = [
"version",
"space",
"metadata.labels",
]
_RESTRICTIONS_EXPANSION_FIELDS = [
"space",
"restrictions.read.restrictions.user",
@@ -83,7 +83,13 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
)
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
class ConfluenceConnector(
LoadConnector,
PollConnector,
SlimConnector,
CredentialsConnector,
VisionEnabledConnector,
):
def __init__(
self,
wiki_base: str,
@@ -100,14 +106,24 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
) -> None:
self.wiki_base = wiki_base
self.is_cloud = is_cloud
self.space = space
self.page_id = page_id
self.index_recursively = index_recursively
self.cql_query = cql_query
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = labels_to_skip
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
self._fetched_titles: set[str] = set()
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
"""
If nothing is provided, we default to fetching all pages
Only one or none of the following options should be specified so
@@ -137,6 +153,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
self.credentials_provider: CredentialsProviderInterface | None = None
self.probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
self.final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
@property
def confluence_client(self) -> OnyxConfluence:
@@ -144,15 +171,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self._confluence_client = build_confluence_client(
credentials=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
self.credentials_provider = credentials_provider
# raises exception if there's a problem
confluence_client = OnyxConfluence(
self.is_cloud, self.wiki_base, credentials_provider
)
return None
confluence_client._probe_connection(**self.probe_kwargs)
confluence_client._initialize_connection(**self.final_kwargs)
self._confluence_client = confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use set_credentials_provider with this connector.")
def _construct_page_query(
self,
@@ -160,7 +194,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
) -> str:
page_query = self.base_cql_page_query + self.cql_label_filter
# Add time filters
if start:
formatted_start_time = datetime.fromtimestamp(
@@ -172,7 +205,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
page_query += f" and lastmodified <= '{formatted_end_time}'"
return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
@@ -183,11 +215,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def _get_comment_string_for_page_id(self, page_id: str) -> str:
comment_string = ""
comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comment in self.confluence_client.paginated_cql_retrieval(
cql=comment_cql,
expand=expand,
@@ -198,116 +229,177 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
confluence_object=comment,
fetched_titles=set(),
)
return comment_string
def _convert_object_to_document(
self, confluence_object: dict[str, Any]
) -> Document | None:
def _convert_page_to_document(self, page: dict[str, Any]) -> Document | None:
"""
Takes in a confluence object, extracts all metadata, and converts it into a document.
If its a page, it extracts the text, adds the comments for the document text.
If its an attachment, it just downloads the attachment and converts that into a document.
Converts a Confluence page to a Document object.
Includes the page content, comments, and attachments.
"""
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
)
try:
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = f"{self.wiki_base}{page['_links']['webui']}"
object_text = None
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
confluence_client=self.confluence_client, attachment=confluence_object
# Get the page content
page_content = extract_text_from_confluence_html(
self.confluence_client, page, self._fetched_titles
)
if object_text is None:
# This only happens for attachments that are not parseable
# Create the main section for the page content
sections = [Section(text=page_content, link=page_url)]
# Process comments if available
comment_text = self._get_comment_string_for_page_id(page_id)
if comment_text:
sections.append(Section(text=comment_text, link=f"{page_url}#comments"))
# Process attachments
if "children" in page and "attachment" in page["children"]:
attachments = self.confluence_client.get_attachments_for_page(
page_id, expand="metadata"
)
for attachment in attachments.get("results", []):
# Process each attachment
result = process_attachment(
self.confluence_client,
attachment,
page_title,
self.image_analysis_llm,
)
if result.text:
# Create a section for the attachment text
attachment_section = Section(
text=result.text,
link=f"{page_url}#attachment-{attachment['id']}",
image_file_name=result.file_name,
)
sections.append(attachment_section)
elif result.error:
logger.warning(
f"Error processing attachment '{attachment.get('title')}': {result.error}"
)
# Extract metadata
metadata = {}
if "space" in page:
metadata["space"] = page["space"].get("name", "")
# Extract labels
labels = []
if "metadata" in page and "labels" in page["metadata"]:
for label in page["metadata"]["labels"].get("results", []):
labels.append(label.get("name", ""))
if labels:
metadata["labels"] = labels
# Extract owners
primary_owners = []
if "version" in page and "by" in page["version"]:
author = page["version"]["by"]
display_name = author.get("displayName", "Unknown")
primary_owners.append(BasicExpertInfo(display_name=display_name))
# Create the document
return Document(
id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud),
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
metadata=metadata,
doc_updated_at=datetime_from_string(page["version"]["when"]),
primary_owners=primary_owners if primary_owners else None,
)
except Exception as e:
logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
if not self.continue_on_failure:
raise
return None
# Get space name
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": confluence_object["space"]["name"]
}
# Get labels
label_dicts = (
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
)
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
if page_labels:
doc_metadata["labels"] = page_labels
# Get last modified and author email
version_dict = confluence_object.get("version", {})
last_modified = (
datetime_from_string(version_dict.get("when"))
if version_dict.get("when")
else None
)
author_email = version_dict.get("by", {}).get("email")
title = confluence_object.get("title", "Untitled Document")
return Document(
id=object_url,
sections=[Section(link=object_url, text=object_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=title,
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author_email)] if author_email else None
),
metadata=doc_metadata,
)
def _fetch_document_batches(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
"""
Yields batches of Documents. For each page:
- Create a Document with 1 Section for the page text/comments
- Then fetch attachments. For each attachment:
- Attempt to convert it with convert_attachment_to_content(...)
- If successful, create a new Section with the extracted text or summary.
"""
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self._construct_page_query(start, end)
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
logger.debug(f"_fetch_document_batches: {page['id']}")
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Build doc from page
doc = self._convert_page_to_document(page)
if not doc:
continue
# Now get attachments for that page:
attachment_query = self._construct_attachment_query(page["id"])
# We'll use the page's XML to provide context if we summarize an image
confluence_xml = page.get("body", {}).get("storage", {}).get("value", "")
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_query = self._construct_attachment_query(confluence_page_id)
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment, self.image_analysis_llm
):
continue
# Attempt to get textual content or image summarization:
try:
logger.info(f"Processing attachment: {attachment['title']}")
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_context=confluence_xml,
llm=self.image_analysis_llm,
)
if response is None:
continue
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
if content_text:
doc.sections.append(
Section(
text=content_text,
link=object_url,
image_file_name=file_storage_name,
)
)
except Exception as e:
logger.error(
f"Failed to extract/summarize attachment {attachment['title']}",
exc_info=e,
)
if not self.continue_on_failure:
raise
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -328,55 +420,63 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""
Return 'slim' docs (IDs + minimal permission data).
Does not fetch actual text. Used primarily for incremental permission sync.
"""
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
# Query pages
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
page_restrictions = page.get("restrictions")
page_space_key = page.get("space", {}).get("key")
page_ancestors = page.get("ancestors", [])
page_perm_sync_data = {
"restrictions": page_restrictions or {},
"space_key": page_space_key,
"ancestors": page_ancestors or [],
"ancestors": page_ancestors,
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
perm_sync_data=page_perm_sync_data,
)
)
# Query attachments for each page
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
if not validate_attachment_filetype(attachment):
# If you skip images, you'll skip them in the permission sync
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment, self.image_analysis_llm
):
continue
attachment_restrictions = attachment.get("restrictions")
attachment_restrictions = attachment.get("restrictions", {})
if not attachment_restrictions:
attachment_restrictions = page_restrictions
attachment_restrictions = page_restrictions or {}
attachment_space_key = attachment.get("space", {}).get("key")
if not attachment_space_key:
attachment_space_key = page_space_key
attachment_perm_sync_data = {
"restrictions": attachment_restrictions or {},
"restrictions": attachment_restrictions,
"space_key": attachment_space_key,
}
@@ -390,16 +490,16 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=attachment_perm_sync_data,
)
)
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
if callback and callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list
@@ -420,11 +520,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
raise InsufficientPermissionsError(
"Insufficient permissions to access Confluence resources (HTTP 403)."
)
raise UnexpectedError(
raise UnexpectedValidationError(
f"Unexpected Confluence error (status={status_code}): {e}"
)
except Exception as e:
raise UnexpectedError(
raise UnexpectedValidationError(
f"Unexpected error while validating Confluence settings: {e}"
)

View File

@@ -1,19 +1,37 @@
import math
import io
import json
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypeVar
from urllib.parse import quote
import bs4
from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from redis import Redis
from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -22,12 +40,14 @@ logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
class ConfluenceRateLimitError(Exception):
pass
@@ -43,124 +63,355 @@ class ConfluenceUser(BaseModel):
type: str
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence(Confluence):
class OnyxConfluence:
"""
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
This is a custom Confluence class that:
A. overrides the default Confluence class to add a custom CQL method.
B.
This is necessary because the default Confluence class does not properly support cql expansions.
All methods are automatically wrapped with handle_confluence_rate_limit.
"""
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
CREDENTIAL_PREFIX = "connector:confluence:credential"
CREDENTIAL_TTL = 300 # 5 min
def _wrap_methods(self) -> None:
def __init__(
self,
is_cloud: bool,
url: str,
credentials_provider: CredentialsProviderInterface,
) -> None:
self._is_cloud = is_cloud
self._url = url.rstrip("/")
self._credentials_provider = credentials_provider
self.redis_client: Redis | None = None
self.static_credentials: dict[str, Any] | None = None
if self._credentials_provider.is_dynamic():
self.redis_client = get_redis_client(
tenant_id=credentials_provider.get_tenant_id()
)
else:
self.static_credentials = self._credentials_provider.get_credentials()
self._confluence = Confluence(url)
self.credential_key: str = (
self.CREDENTIAL_PREFIX
+ f":credential_{self._credentials_provider.get_provider_key()}"
)
self._kwargs: Any = None
self.shared_base_kwargs = {
"api_version": "cloud" if is_cloud else "latest",
"backoff_and_retry": True,
"cloud": is_cloud,
}
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
"""credential_json - the current json credentials
Returns a tuple
1. The up to date credentials
2. True if the credentials were updated
This method is intended to be used within a distributed lock.
Lock, call this, update credentials if the tokens were refreshed, then release
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
wrap it with handle_confluence_rate_limit.
"""
for attr_name in dir(self):
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
setattr(
self,
attr_name,
handle_confluence_rate_limit(getattr(self, attr_name)),
# static credentials are preloaded, so no locking/redis required
if self.static_credentials:
return self.static_credentials, False
if not self.redis_client:
raise RuntimeError("self.redis_client is None")
# dynamic credentials need locking
# check redis first, then fallback to the DB
credential_raw = self.redis_client.get(self.credential_key)
if credential_raw is not None:
credential_bytes = cast(bytes, credential_raw)
credential_str = credential_bytes.decode("utf-8")
credential_json: dict[str, Any] = json.loads(credential_str)
else:
credential_json = self._credentials_provider.get_credentials()
if "confluence_refresh_token" not in credential_json:
# static credentials ... cache them permanently and return
self.static_credentials = credential_json
return credential_json, False
if not OAUTH_CONFLUENCE_CLOUD_CLIENT_ID:
raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID must be set!")
if not OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET:
raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET must be set!")
# check if we should refresh tokens. we're deciding to refresh halfway
# to expiration
now = datetime.now(timezone.utc)
created_at = datetime.fromisoformat(credential_json["created_at"])
expires_in: int = credential_json["expires_in"]
renew_at = created_at + timedelta(seconds=expires_in // 2)
if now <= renew_at:
# cached/current credentials are reasonably up to date
return credential_json, False
# we need to refresh
logger.info("Renewing Confluence Cloud credentials...")
new_credentials = confluence_refresh_tokens(
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET,
credential_json["cloud_id"],
credential_json["confluence_refresh_token"],
)
# store the new credentials to redis and to the db thru the provider
# redis: we use a 5 min TTL because we are given a 10 minute grace period
# when keys are rotated. it's easier to expire the cached credentials
# reasonably frequently rather than trying to handle strong synchronization
# between the db and redis everywhere the credentials might be updated
new_credential_str = json.dumps(new_credentials)
self.redis_client.set(
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
)
self._credentials_provider.set_credentials(new_credentials)
return new_credentials, True
@staticmethod
def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]:
oauth2_dict: dict[str, Any] = {}
if "confluence_refresh_token" in credentials:
oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
oauth2_dict["token"] = {}
oauth2_dict["token"]["access_token"] = credentials[
"confluence_access_token"
]
return oauth2_dict
def _probe_connection(
self,
**kwargs: Any,
) -> None:
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
with self._credentials_provider:
credentials, _ = self._renew_credentials()
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Probing Confluence with OAuth Access Token.")
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
credentials
)
url = (
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
)
confluence_client_with_minimal_retries = Confluence(
url=url, oauth2=oauth2_dict, **merged_kwargs
)
else:
logger.info("Probing Confluence with Personal Access Token.")
url = self._url
if self._is_cloud:
confluence_client_with_minimal_retries = Confluence(
url=url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**merged_kwargs,
)
else:
confluence_client_with_minimal_retries = Confluence(
url=url,
token=credentials["confluence_access_token"],
**merged_kwargs,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {url}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
logger.info("Confluence probe succeeded.")
def _initialize_connection(
self,
**kwargs: Any,
) -> None:
"""Called externally to init the connection in a thread safe manner."""
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
with self._credentials_provider:
credentials, _ = self._renew_credentials()
self._confluence = self._initialize_connection_helper(
credentials, **merged_kwargs
)
self._kwargs = merged_kwargs
def _initialize_connection_helper(
self,
credentials: dict[str, Any],
**kwargs: Any,
) -> Confluence:
"""Called internally to init the connection. Distributed locking
to prevent multiple threads from modifying the credentials
must be handled around this function."""
confluence = None
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Connecting to Confluence Cloud with OAuth Access Token.")
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials)
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
else:
logger.info("Connecting to Confluence with Personal Access Token.")
if self._is_cloud:
confluence = Confluence(
url=self._url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**kwargs,
)
else:
confluence = Confluence(
url=self._url,
token=credentials["confluence_access_token"],
**kwargs,
)
return confluence
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def _make_rate_limited_confluence_method(
self, name: str, credential_provider: CredentialsProviderInterface | None
) -> Callable[..., Any]:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
try:
if credential_provider:
with credential_provider:
credentials, renewed = self._renew_credentials()
if renewed:
self._confluence = self._initialize_connection_helper(
credentials, **self._kwargs
)
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return attr(*args, **kwargs)
else:
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return attr(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return wrapped_call
# def _wrap_methods(self) -> None:
# """
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
# wrap it with handle_confluence_rate_limit.
# """
# for attr_name in dir(self):
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
# setattr(
# self,
# attr_name,
# handle_confluence_rate_limit(getattr(self, attr_name)),
# )
# def _ensure_token_valid(self) -> None:
# if self._token_is_expired():
# self._refresh_token()
# # Re-init the Confluence client with the originally stored args
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
def __getattr__(self, name: str) -> Any:
"""Dynamically intercept attribute/method access."""
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
# If it's not a method, just return it after ensuring token validity
if not callable(attr):
return attr
# skip methods that start with "_"
if name.startswith("_"):
return attr
# wrap the method with our retry handler
rate_limited_method: Callable[
..., Any
] = self._make_rate_limited_confluence_method(name, self._credentials_provider)
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
return rate_limited_method(*args, **kwargs)
return wrapped_method
def _paginate_url(
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
@@ -507,63 +758,212 @@ class OnyxConfluence(Confluence):
return response
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we dont
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e: the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
return extracted_text
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formated Confluence page
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
try:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
except Exception as e:
raise ConnectorValidationError(str(e))
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
cloud=is_cloud,
)
return format_document_soup(soup)

View File

@@ -1,239 +1,280 @@
import io
import math
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
import requests
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
from onyx.configs.constants import FileOrigin
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import PGFileStore
from onyx.db.pg_file_store import create_populate_lobj
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
from onyx.db.pg_file_store import upsert_pgfilestore
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
_USER_EMAIL_CACHE: dict[str, str | None] = {}
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
def get_user_email_from_username__server(
confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we dont
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e: the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
def validate_attachment_filetype(
attachment: dict[str, Any], llm: LLM | None = None
) -> bool:
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def extract_text_from_confluence_html(
confluence_client: "OnyxConfluence",
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formated Confluence page
Validates if the attachment is a supported file type.
If LLM is provided, also checks if it's an image that can be processed.
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
attachment.get("metadata", {})
media_type = attachment.get("metadata", {}).get("mediaType", "")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
if media_type.startswith("image/"):
return llm is not None and is_valid_image_type(media_type)
# For non-image files, check if we support the extension
title = attachment.get("title", "")
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
return extension in ["pdf", "doc", "docx", "txt", "md", "rtf"]
class AttachmentProcessingResult(BaseModel):
"""
A container for results after processing a Confluence attachment.
'text' is the textual content of the attachment.
'file_name' is the final file name used in PGFileStore to store the content.
'error' holds an exception or string if something failed.
"""
text: str | None
file_name: str | None
error: str | None = None
def _download_attachment(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> bytes | None:
"""
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
"""
download_link = confluence_client.url + attachment["_links"]["download"]
resp = confluence_client._session.get(download_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with status code {resp.status_code}"
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
return None
return resp.content
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
return attachment["metadata"]["mediaType"] not in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]
def attachment_to_content(
def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
page_context: str,
llm: LLM | None,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
or if it's an image and an LLM is available, summarizes it. Returns a structured result.
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
download_link = confluence_client.url + attachment["_links"]["download"]
# Validate the attachment type
if not validate_attachment_filetype(attachment, llm):
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Unsupported file type: {media_type}",
)
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
# Download the attachment
raw_bytes = _download_attachment(confluence_client, attachment)
if raw_bytes is None:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to download attachment"
)
# Process image attachments with LLM if available
if media_type.startswith("image/") and llm:
return _process_image_attachment(
confluence_client, attachment, page_context, llm, raw_bytes, media_type
)
# Process document attachments
try:
text = extract_file_text(
file=BytesIO(raw_bytes),
file_name=attachment["title"],
)
# Skip if the text is too long
if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {len(text)} chars",
)
return AttachmentProcessingResult(text=text, file_name=None, error=None)
except Exception as e:
return AttachmentProcessingResult(
text=None, file_name=None, error=f"Failed to extract text: {e}"
)
except Exception as e:
return AttachmentProcessingResult(
text=None, file_name=None, error=f"Failed to process attachment: {e}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
def _process_image_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM,
raw_bytes: bytes,
media_type: str,
) -> AttachmentProcessingResult:
"""Process an image attachment by saving it and generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
section, file_name = store_image_and_create_section(
db_session=db_session,
image_data=raw_bytes,
file_name=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
llm=llm,
file_origin=FileOrigin.CONNECTOR,
)
return AttachmentProcessingResult(
text=section.text, file_name=file_name, error=None
)
except Exception as e:
msg = f"Image summarization failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
def _process_text_attachment(
attachment: dict[str, Any],
raw_bytes: bytes,
media_type: str,
) -> AttachmentProcessingResult:
"""Process a text-based attachment by extracting its content."""
try:
extracted_text = extract_file_text(
io.BytesIO(raw_bytes),
file_name=attachment["title"],
break_on_unprocessable=False,
)
return None
except Exception as e:
msg = f"Failed to extract text for '{attachment['title']}': {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
# Check length constraints
if extracted_text is None or len(extracted_text) == 0:
msg = f"No text extracted for {attachment['title']}"
logger.warning(msg)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
msg = (
f"Skipping attachment {attachment['title']} due to char count "
f"({len(extracted_text)} > {CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD})"
)
logger.warning(msg)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
# Save the attachment
try:
with get_session_with_current_tenant() as db_session:
saved_record = save_bytes_to_pgfilestore(
db_session=db_session,
raw_bytes=raw_bytes,
media_type=media_type,
identifier=attachment["id"],
display_name=attachment["title"],
)
except Exception as e:
msg = f"Failed to save attachment '{attachment['title']}' to PG: {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(
text=extracted_text, file_name=None, error=msg
)
return AttachmentProcessingResult(
text=extracted_text, file_name=saved_record.file_name, error=None
)
def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM | None,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
1. Validates attachment type
2. Extracts or summarizes content
3. Returns (content_text, stored_file_name) or None if we should skip it
"""
media_type = attachment["metadata"]["mediaType"]
# Quick check for unsupported types:
if media_type.startswith("video/") or media_type == "application/gliffy+json":
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
f"Skipping unsupported attachment type: '{media_type}' for {attachment['title']}"
)
return None
return extracted_text
result = process_attachment(confluence_client, attachment, page_context, llm)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"
)
return None
# Return the text and the file name
return result.text, result.file_name
def build_confluence_document_id(
@@ -254,23 +295,6 @@ def build_confluence_document_id(
return f"{base_url}{content_url}"
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachments in use
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
referenced_attachment_filenames = []
soup = bs4.BeautifulSoup(page_text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
referenced_attachment_filenames.append(attachment.attrs["ri:filename"])
return referenced_attachment_filenames
def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime.fromisoformat(datetime_string)
@@ -284,6 +308,137 @@ def datetime_from_string(datetime_string: str) -> datetime:
return datetime_object
def confluence_refresh_tokens(
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
) -> dict[str, Any]:
# rotate the refresh and access token
# Note that access tokens are only good for an hour in confluence cloud,
# so we're going to have problems if the connector runs for longer
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair
response = requests.post(
CONFLUENCE_OAUTH_TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"grant_type": "refresh_token",
"client_id": client_id,
"client_secret": client_secret,
"refresh_token": refresh_token,
},
)
try:
token_response = TokenResponse.model_validate_json(response.text)
except Exception:
raise RuntimeError("Confluence Cloud token refresh failed.")
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
new_credentials: dict[str, Any] = {}
new_credentials["confluence_access_token"] = token_response.access_token
new_credentials["confluence_refresh_token"] = token_response.refresh_token
new_credentials["created_at"] = now.isoformat()
new_credentials["expires_at"] = expires_at.isoformat()
new_credentials["expires_in"] = token_response.expires_in
new_credentials["scope"] = token_response.scope
new_credentials["cloud_id"] = cloud_id
return new_credentials
F = TypeVar("F", bound=Callable[..., Any])
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except requests.HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)
@@ -311,3 +466,37 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
+ "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
)
def attachment_to_file_record(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
db_session: Session,
) -> tuple[PGFileStore, bytes]:
"""Save an attachment to the file store and return the file record."""
download_link = _attachment_to_download_link(confluence_client, attachment)
image_data = confluence_client.get(
download_link, absolute=True, not_json_response=True
)
# Save image to file store
file_name = f"confluence_attachment_{attachment['id']}"
lobj_oid = create_populate_lobj(BytesIO(image_data), db_session)
pgfilestore = upsert_pgfilestore(
file_name=file_name,
display_name=attachment["title"],
file_origin=FileOrigin.OTHER,
file_type=attachment["metadata"]["mediaType"],
lobj_oid=lobj_oid,
db_session=db_session,
commit=True,
)
return pgfilestore, image_data
def _attachment_to_download_link(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> str:
"""Extracts the download link to images."""
return confluence_client.url + attachment["_links"]["download"]

View File

@@ -0,0 +1,135 @@
import uuid
from types import TracebackType
from typing import Any
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import Credential
from onyx.redis.redis_pool import get_redis_client
class OnyxDBCredentialsProvider(
CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
"""Implementation to allow the connector to callback and update credentials in the db.
Required in cases where credentials can rotate while the connector is running.
"""
LOCK_TTL = 900 # TTL of the lock
def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_id = credential_id
self.redis_client = get_redis_client(tenant_id=tenant_id)
# lock used to prevent overlapping renewal of credentials
self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)
def __enter__(self) -> "OnyxDBCredentialsProvider":
acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
if not acquired:
raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Release the lock when exiting the context."""
if self._lock and self._lock.owned():
self._lock.release()
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return str(self._credential_id)
def get_credentials(self) -> dict[str, Any]:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
credential = db_session.execute(
select(Credential).where(Credential.id == self._credential_id)
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
return credential.credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
try:
credential = db_session.execute(
select(Credential)
.where(Credential.id == self._credential_id)
.with_for_update()
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
credential.credential_json = credential_json
db_session.commit()
except Exception:
db_session.rollback()
raise
def is_dynamic(self) -> bool:
return True
class OnyxStaticCredentialsProvider(
CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
"""Implementation (a very simple one!) to handle static credentials."""
def __init__(
self,
tenant_id: str | None,
connector_name: str,
credential_json: dict[str, Any],
):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_json = credential_json
self._provider_key = str(uuid.uuid4())
def __enter__(self) -> "OnyxStaticCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return self._provider_key
def get_credentials(self) -> dict[str, Any]:
return self._credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
self._credential_json = credential_json
def is_dynamic(self) -> bool:
return False

View File

@@ -14,12 +14,15 @@ class ConnectorValidationError(ValidationError):
super().__init__(self.message)
class UnexpectedError(ValidationError):
class UnexpectedValidationError(ValidationError):
"""Raised when an unexpected error occurs during connector validation.
Unexpected errors don't necessarily mean the credential is invalid,
but rather that there was an error during the validation process
or we encountered a currently unhandled error case.
Currently, unexpected validation errors are defined as transient and should not be
used to disable the connector.
"""
def __init__(self, message: str = "Unexpected error during connector validation"):

View File

@@ -12,6 +12,7 @@ from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
@@ -32,6 +33,7 @@ from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -57,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.models import Credential
from shared_configs.contextvars import get_current_tenant_id
class ConnectorMissingException(Exception):
@@ -167,10 +170,17 @@ def instantiate_connector(
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
if isinstance(connector, CredentialsConnector):
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), str(source), credential.id
)
connector.set_credentials_provider(provider)
else:
new_credentials = connector.load_credentials(credential.credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
return connector

View File

@@ -10,22 +10,23 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -35,81 +36,115 @@ def _read_files_and_metadata(
file_name: str,
db_session: Session,
) -> Iterator[tuple[str, IO, dict[str, Any]]]:
"""Reads the file into IO, in the case of a zip file, yields each individual
file contained within, also includes the metadata dict if packaged in the zip"""
"""
Reads the file from Postgres. If the file is a .zip, yields subfiles.
"""
extension = get_file_ext(file_name)
metadata: dict[str, Any] = {}
directory_path = os.path.dirname(file_name)
# Read file from Postgres store
file_content = get_default_file_store(db_session).read_file(file_name, mode="b")
# If it's a zip, expand it
if extension == ".zip":
for file_info, file, metadata in load_files_from_zip(
for file_info, subfile, metadata in load_files_from_zip(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), file, metadata
yield os.path.join(directory_path, file_info.filename), subfile, metadata
elif is_valid_file_ext(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
def _create_image_section(
llm: LLM | None,
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
idx: int = 0,
) -> tuple[Section, str | None]:
"""
Create a Section object for a single image and store the image in PGFileStore.
If summarization is enabled and we have an LLM, summarize the image.
Returns:
tuple: (Section object, file_name in PGFileStore or None if storage failed)
"""
# Create a unique file name for the embedded image
file_name = f"{parent_file_name}_embedded_{idx}"
# Use the standardized utility to store the image and create a section
return store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
llm=llm,
file_origin=FileOrigin.OTHER,
)
def _process_file(
file_name: str,
file: IO[Any],
metadata: dict[str, Any] | None = None,
pdf_pass: str | None = None,
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
llm: LLM | None,
) -> list[Document]:
"""
Processes a single file, returning a list of Documents (typically one).
Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true.
"""
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
# Fetch the DB record so we know the ID for internal URL
pg_record = get_pgfilestore_by_file_name(file_name=file_name, db_session=db_session)
if not pg_record:
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
return []
file_metadata: dict[str, Any] = {}
if is_text_file_extension(file_name):
encoding = detect_encoding(file)
file_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
if not is_valid_file_ext(extension):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)
return []
# Using the PDF reader function directly to pass in password cleanly
elif extension == ".pdf" and pdf_pass is not None:
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
# Prepare doc metadata
if metadata is None:
metadata = {}
file_display_name = metadata.get("file_display_name") or os.path.basename(file_name)
else:
file_content_raw = extract_file_text(
file=file,
file_name=file_name,
break_on_unprocessable=True,
)
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
# add a prefix to avoid conflicts with other connectors
doc_id = f"FILE_CONNECTOR__{file_name}"
if metadata:
doc_id = metadata.get("document_id") or doc_id
# If this is set, we will show this in the UI as the "name" of the file
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
file_name
)
title = (
all_metadata["title"] or "" if "title" in all_metadata else file_display_name
)
time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
# Timestamps
current_datetime = datetime.now(timezone.utc)
time_updated = metadata.get("time_updated", current_datetime)
if isinstance(time_updated, str):
time_updated = time_str_to_utc(time_updated)
dt_str = all_metadata.get("doc_updated_at")
dt_str = metadata.get("doc_updated_at")
final_time_updated = time_str_to_utc(dt_str) if dt_str else time_updated
# Metadata tags separate from the Onyx specific fields
# Collect owners
p_owner_names = metadata.get("primary_owners")
s_owner_names = metadata.get("secondary_owners")
p_owners = (
[BasicExpertInfo(display_name=name) for name in p_owner_names]
if p_owner_names
else None
)
s_owners = (
[BasicExpertInfo(display_name=name) for name in s_owner_names]
if s_owner_names
else None
)
# Additional tags we store as doc metadata
metadata_tags = {
k: v
for k, v in all_metadata.items()
for k, v in metadata.items()
if k
not in [
"document_id",
@@ -122,77 +157,142 @@ def _process_file(
"file_display_name",
"title",
"connector_type",
"pdf_password",
]
}
source_type_str = all_metadata.get("connector_type")
source_type = DocumentSource(source_type_str) if source_type_str else None
p_owner_names = all_metadata.get("primary_owners")
s_owner_names = all_metadata.get("secondary_owners")
p_owners = (
[BasicExpertInfo(display_name=name) for name in p_owner_names]
if p_owner_names
else None
)
s_owners = (
[BasicExpertInfo(display_name=name) for name in s_owner_names]
if s_owner_names
else None
source_type_str = metadata.get("connector_type")
source_type = (
DocumentSource(source_type_str) if source_type_str else DocumentSource.FILE
)
doc_id = metadata.get("document_id") or f"FILE_CONNECTOR__{file_name}"
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
if extension in IMAGE_EXTENSIONS:
# Summarize or produce empty doc
image_data = file.read()
image_section, _ = _create_image_section(
llm, image_data, db_session, pg_record.file_name, title
)
return [
Document(
id=doc_id,
sections=[image_section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
# 2) Otherwise: text-based approach. Possibly with embedded images if enabled.
# (For example .docx with inline images).
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
text_content, embedded_images = extract_text_and_images(
file=file,
file_name=file_name,
pdf_pass=pdf_pass,
)
# Build sections: first the text as a single Section
sections = []
link_in_meta = metadata.get("link")
if text_content.strip():
sections.append(Section(link=link_in_meta, text=text_content.strip()))
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image summary
image_section, _ = _create_image_section(
llm,
img_data,
db_session,
pg_record.file_name,
f"{title} - image {idx}",
idx,
)
sections.append(image_section)
return [
Document(
id=doc_id,
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=source_type or DocumentSource.FILE,
sections=sections,
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
# currently metadata just houses tags, other stuff like owners / updated at have dedicated fields
metadata=metadata_tags,
)
]
class LocalFileConnector(LoadConnector):
class LocalFileConnector(LoadConnector, VisionEnabledConnector):
"""
Connector that reads files from Postgres and yields Documents, including
optional embedded image extraction.
"""
def __init__(
self,
file_locations: list[Path | str],
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.file_locations = [str(loc) for loc in file_locations]
self.batch_size = batch_size
self.pdf_pass: str | None = None
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
return None
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Iterates over each file path, fetches from Postgres, tries to parse text
or images, and yields Document batches.
"""
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
file_name=str(file_path), db_session=db_session
files_iter = _read_files_and_metadata(
file_name=file_path,
db_session=db_session,
)
for file_name, file, metadata in files:
for actual_file_name, file, metadata in files_iter:
metadata["time_updated"] = metadata.get(
"time_updated", current_datetime
)
documents.extend(
_process_file(file_name, file, metadata, self.pdf_pass)
new_docs = _process_file(
file_name=actual_file_name,
file=file,
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
llm=self.image_analysis_llm,
)
documents.extend(new_docs)
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
@@ -201,7 +301,7 @@ class LocalFileConnector(LoadConnector):
if __name__ == "__main__":
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
connector.load_credentials({"pdf_password": os.environ["PDF_PASSWORD"]})
document_batches = connector.load_from_state()
print(next(document_batches))
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
doc_batches = connector.load_from_state()
for batch in doc_batches:
print("BATCH:", batch)

View File

@@ -20,7 +20,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -124,14 +124,14 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repo_name: str | None = None,
repositories: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repo_name = repo_name
self.repositories = repositories
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
@@ -157,11 +157,42 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
repos = []
# Split repo_name by comma and strip whitespace
repo_names = [
name.strip() for name in (cast(str, self.repositories)).split(",")
]
for repo_name in repo_names:
if repo_name: # Skip empty strings
try:
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
repos.append(repo)
except GithubException as e:
logger.warning(
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
)
return repos
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repos(github_client, attempt_num + 1)
def _get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
@@ -189,11 +220,17 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
for repo in repos:
if self.include_prs:
@@ -268,11 +305,48 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repo_names = [name.strip() for name in self.repositories.split(",")]
if not repo_names:
raise ConnectorValidationError(
"Invalid connector settings: No valid repository names provided."
)
# Validate at least one repository exists and is accessible
valid_repos = False
validation_errors = []
for repo_name in repo_names:
if not repo_name:
continue
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
break
except GithubException as e:
validation_errors.append(
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
)
if not valid_repos:
error_msg = (
"None of the specified repositories could be accessed: "
)
error_msg += ", ".join(validation_errors)
raise ConnectorValidationError(error_msg)
else:
# Single repository (backward compatibility)
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repositories}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
@@ -284,7 +358,7 @@ class GithubConnector(LoadConnector, PollConnector):
user.get_repos().totalCount # Just check if we can access repos
except RateLimitExceededException:
raise UnexpectedError(
raise UnexpectedValidationError(
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
)
@@ -298,10 +372,15 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
if self.repositories:
if "," in self.repositories:
raise ConnectorValidationError(
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
@@ -310,6 +389,7 @@ class GithubConnector(LoadConnector, PollConnector):
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except Exception as exc:
raise Exception(
f"Unexpected error during GitHub settings validation: {exc}"
@@ -321,7 +401,7 @@ if __name__ == "__main__":
connector = GithubConnector(
repo_owner=os.environ["REPO_OWNER"],
repo_name=os.environ["REPO_NAME"],
repositories=os.environ["REPOSITORIES"],
)
connector.load_credentials(
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}

View File

@@ -4,14 +4,12 @@ from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@@ -36,7 +34,6 @@ from onyx.connectors.google_utils.shared_constants import (
)
from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from onyx.connectors.google_utils.shared_constants import SCOPE_DOC_URL
from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from onyx.connectors.google_utils.shared_constants import USER_FIELDS
from onyx.connectors.interfaces import GenerateDocumentsOutput
@@ -46,7 +43,9 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -66,7 +65,10 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any, primary_admin_email: str, file: dict[str, Any]
creds: Any,
primary_admin_email: str,
file: dict[str, Any],
image_analysis_llm: LLM | None,
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
@@ -75,11 +77,14 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images
)
def _process_files_batch(
files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int
files: list[GoogleDriveFileType],
convert_func: Callable[[GoogleDriveFileType], Any],
batch_size: int,
) -> GenerateDocumentsOutput:
doc_batch = []
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
@@ -111,7 +116,9 @@ def _clean_requested_drive_ids(
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
class GoogleDriveConnector(
LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector
):
def __init__(
self,
include_shared_drives: bool = False,
@@ -129,23 +136,23 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
continue_on_failure: bool | None = None,
) -> None:
# Check for old input parameters
if (
folder_paths is not None
or include_shared is not None
or follow_shortcuts is not None
or only_org_public is not None
or continue_on_failure is not None
):
logger.exception(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
if folder_paths is not None:
logger.warning(
"The 'folder_paths' parameter is deprecated. Use 'shared_folder_urls' instead."
)
raise ConnectorValidationError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
if include_shared is not None:
logger.warning(
"The 'include_shared' parameter is deprecated. Use 'include_files_shared_with_me' instead."
)
if follow_shortcuts is not None:
logger.warning("The 'follow_shortcuts' parameter is deprecated.")
if only_org_public is not None:
logger.warning("The 'only_org_public' parameter is deprecated.")
if continue_on_failure is not None:
logger.warning("The 'continue_on_failure' parameter is deprecated.")
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
if (
not include_shared_drives
@@ -237,6 +244,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
credentials=credentials,
source=DocumentSource.GOOGLE_DRIVE,
)
return new_creds_dict
def _update_traversed_parent_ids(self, folder_id: str) -> None:
@@ -523,37 +531,53 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
# Create a larger process pool for file conversion
convert_func = partial(
_convert_single_file, self.creds, self.primary_admin_email
)
# Process files in larger batches
LARGE_BATCH_SIZE = self.batch_size * 4
files_to_process = []
# Gather the files into batches to be processed in parallel
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
if (
file.get("size")
and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
):
logger.warning(
f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
)
continue
files_to_process.append(file)
if len(files_to_process) >= LARGE_BATCH_SIZE:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
)
files_to_process = []
# Process any remaining files
if files_to_process:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM
)
# Fetch files in batches
files_batch: list[GoogleDriveFileType] = []
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
files_batch.append(file)
if len(files_batch) >= self.batch_size:
# Process the batch
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
logger.error(f"Error converting file: {e}")
if documents:
yield documents
files_batch = []
# Process any remaining files
if files_batch:
futures = [executor.submit(convert_func, file) for file in files_batch]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
logger.error(f"Error converting file: {e}")
if documents:
yield documents
def load_from_state(self) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive()

View File

@@ -9,7 +9,7 @@ from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.configs.constants import FileOrigin
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
@@ -21,32 +21,88 @@ from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.file_processing.extract_file_text import docx_to_text
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
def _summarize_drive_image(
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
) -> str:
"""
Summarize the given image using the provided LLM.
"""
if not image_analysis_llm:
return ""
return (
summarize_image_with_error_handling(
llm=image_analysis_llm,
image_data=image_data,
context_name=image_name,
)
or ""
)
def is_gdrive_image_mime_type(mime_type: str) -> bool:
"""
Return True if the mime_type is a common image type in GDrive.
(e.g. 'image/png', 'image/jpeg')
"""
return is_valid_image_type(mime_type)
def _extract_sections_basic(
file: dict[str, str], service: GoogleDriveService
file: dict[str, str],
service: GoogleDriveService,
image_analysis_llm: LLM | None = None,
) -> list[Section]:
"""
Extends the existing logic to handle either a docx with embedded images
or standalone images (PNG, JPG, etc).
"""
mime_type = file["mimeType"]
link = file["webViewLink"]
file_name = file.get("name", file["id"])
supported_file_types = set(item.value for item in GDriveMimeType)
# 1) If the file is an image, retrieve the raw bytes, optionally summarize
if is_gdrive_image_mime_type(mime_type):
try:
response = service.files().get_media(fileId=file["id"]).execute()
with get_session_with_current_tenant() as db_session:
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file["id"],
display_name=file_name,
media_type=mime_type,
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
return [section]
except Exception as e:
logger.warning(f"Failed to fetch or summarize image: {e}")
return [
Section(
link=link,
text="",
image_file_name=link,
)
]
if mime_type not in supported_file_types:
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
@@ -185,45 +241,63 @@ def _extract_sections_basic(
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return [
Section(
link=link,
text=service.files()
.get_media(fileId=file["id"])
.execute()
.decode("utf-8"),
)
]
text_data = (
service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
)
return [Section(link=link, text=text_data)]
# ---------------------------
# Word, PowerPoint, PDF files
if mime_type in [
elif mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
response_bytes = service.files().get_media(fileId=file["id"]).execute()
# Optionally use Unstructured
if get_unstructured_api_key():
return [
Section(
link=link,
text=unstructured_to_text(
file=io.BytesIO(response),
file_name=file.get("name", file["id"]),
),
)
]
text = unstructured_to_text(
file=io.BytesIO(response_bytes),
file_name=file_name,
)
return [Section(link=link, text=text)]
if mime_type == GDriveMimeType.WORD_DOC.value:
return [
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
]
# Use docx_to_text_and_images to get text plus embedded images
text, embedded_images = docx_to_text_and_images(
file=io.BytesIO(response_bytes),
)
sections = []
if text.strip():
sections.append(Section(link=link, text=text.strip()))
# Process each embedded image using the standardized function
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(
embedded_images, start=1
):
# Create a unique identifier for the embedded image
embedded_id = f"{file['id']}_embedded_{idx}"
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=embedded_id,
display_name=img_name or f"{file_name} - image {idx}",
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
sections.append(section)
return sections
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
return [
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
text_data = pptx_to_text(io.BytesIO(response_bytes))
return [Section(link=link, text=text_data)]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
@@ -231,7 +305,8 @@ def _extract_sections_basic(
logger.error(error_message)
raise ValueError(error_message)
except Exception:
except Exception as e:
logger.exception(f"Error extracting sections from file: {e}")
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
@@ -239,74 +314,62 @@ def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
image_analysis_llm: LLM | None,
) -> Document | None:
"""
Main entry point for converting a Google Drive file => Document object.
Now we accept an optional `llm` to pass to `_extract_sections_basic`.
"""
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
return None
# Skip files that are folders
if file.get("mimeType") == DRIVE_FOLDER_TYPE:
logger.info("Ignoring Drive Folder Filetype")
# skip shortcuts or folders
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
logger.info("Skipping shortcut/folder.")
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[Section] = []
# Special handling for Google Docs to preserve structure, link
# to headers
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
# get_document_sections is the advanced approach for Google Docs
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
" Falling back to basic extraction."
f"Failed to pull google doc sections from '{file['name']}': {e}. "
"Falling back to basic extraction."
)
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
# If not a doc, or if we failed above, do our 'basic' approach
if not sections:
try:
# For all other file types just extract the text
sections = _extract_sections_basic(file, drive_service)
sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise
if not sections:
return None
doc_id = file["webViewLink"]
updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
)
return Document(
id=file["webViewLink"],
id=doc_id,
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
),
metadata={}
if any(section.text for section in sections)
else {IGNORE_FOR_QA: "True"},
doc_updated_at=updated_time,
metadata={}, # or any metadata from 'file'
additional_info=file.get("id"),
)
except Exception as e:
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise e
logger.exception("Ran into exception when pulling a file from Google Drive")
except Exception as e:
logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}")
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
# Skip files that are folders or shortcuts
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
return SlimDocument(
id=file["webViewLink"],
perm_sync_data={

View File

@@ -1,7 +1,10 @@
import abc
from collections.abc import Generator
from collections.abc import Iterator
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeVar
from pydantic import BaseModel
@@ -111,6 +114,69 @@ class OAuthConnector(BaseConnector):
raise NotImplementedError
T = TypeVar("T", bound="CredentialsProviderInterface")
class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def __enter__(self) -> T:
raise NotImplementedError
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def get_tenant_id(self) -> str | None:
raise NotImplementedError
@abc.abstractmethod
def get_provider_key(self) -> str:
"""a unique key that the connector can use to lock around a credential
that might be used simultaneously.
Will typically be the credential id, but can also just be something random
in cases when there is nothing to lock (aka static credentials)
"""
raise NotImplementedError
@abc.abstractmethod
def get_credentials(self) -> dict[str, Any]:
raise NotImplementedError
@abc.abstractmethod
def set_credentials(self, credential_json: dict[str, Any]) -> None:
raise NotImplementedError
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
needs to use the locking features of the credentials provider to operate
correctly.
If static, the client can simply reference the credentials once and use them
through the entire indexing run.
"""
raise NotImplementedError
class CredentialsConnector(BaseConnector):
"""Implement this if the connector needs to be able to read and write credentials
on the fly. Typically used with shared credentials/tokens that might be renewed
at any time."""
@abc.abstractmethod
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
raise NotImplementedError
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod

View File

@@ -28,7 +28,8 @@ class ConnectorMissingCredentialError(PermissionError):
class Section(BaseModel):
text: str
link: str | None
link: str | None = None
image_file_name: str | None = None
class BasicExpertInfo(BaseModel):

View File

@@ -19,7 +19,7 @@ from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -671,12 +671,12 @@ class NotionConnector(LoadConnector, PollConnector):
"Please try again later."
)
else:
raise UnexpectedError(
raise UnexpectedValidationError(
f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
) from http_err
except Exception as exc:
raise UnexpectedError(
raise UnexpectedValidationError(
f"Unexpected error during Notion settings validation: {exc}"
)

View File

@@ -21,7 +21,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@@ -702,7 +702,9 @@ class SlackConnector(SlimConnector, CheckpointConnector):
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({error_msg})."
)
raise UnexpectedError(f"Slack API returned a failure: {error_msg}")
raise UnexpectedValidationError(
f"Slack API returned a failure: {error_msg}"
)
# 3) If channels are specified, verify each is accessible
if self.channels:
@@ -740,13 +742,13 @@ class SlackConnector(SlimConnector, CheckpointConnector):
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
raise UnexpectedError(
raise UnexpectedValidationError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
except ConnectorValidationError as e:
raise e
except Exception as e:
raise UnexpectedError(
raise UnexpectedValidationError(
f"Unexpected error during Slack settings validation: {e}"
)

View File

@@ -72,6 +72,7 @@ def make_slack_api_rate_limited(
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
for _ in range(max_retries):
try:
# Make the API call

View File

@@ -16,7 +16,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -302,7 +302,7 @@ class TeamsConnector(LoadConnector, PollConnector):
raise InsufficientPermissionsError(
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
)
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
raise UnexpectedValidationError(f"Unexpected error retrieving teams: {e}")
except Exception as e:
error_str = str(e).lower()

View File

@@ -0,0 +1,45 @@
"""
Mixin for connectors that need vision capabilities.
"""
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
class VisionEnabledConnector:
"""
Mixin for connectors that need vision capabilities.
This mixin provides a standard way to initialize a vision-capable LLM
for image analysis during indexing.
Usage:
class MyConnector(LoadConnector, VisionEnabledConnector):
def __init__(self, ...):
super().__init__(...)
self.initialize_vision_llm()
"""
def initialize_vision_llm(self) -> None:
"""
Initialize a vision-capable LLM if enabled by configuration.
Sets self.image_analysis_llm to the LLM instance or None if disabled.
"""
self.image_analysis_llm: LLM | None = None
if get_image_extraction_and_analysis_enabled():
try:
self.image_analysis_llm = get_default_llm_with_vision()
if self.image_analysis_llm is None:
logger.warning(
"No LLM with vision found; image summarization will be disabled"
)
except Exception as e:
logger.warning(
f"Failed to initialize vision LLM due to an error: {str(e)}. "
"Image summarization will be disabled."
)
self.image_analysis_llm = None

View File

@@ -28,7 +28,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
@@ -42,6 +42,10 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
# Threshold for determining when to replace vs append iframe content
IFRAME_TEXT_LENGTH_THRESHOLD = 700
# Message indicating JavaScript is disabled, which often appears when scraping fails
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
@@ -138,7 +142,8 @@ def get_internal_links(
# Account for malformed backslashes in URLs
href = href.replace("\\", "/")
if should_ignore_pound and "#" in href:
# "#!" indicates the page is using a hashbang URL, which is a client-side routing technique
if should_ignore_pound and "#" in href and "#!" not in href:
href = href.split("#")[0]
if not is_valid_url(href):
@@ -152,6 +157,7 @@ def get_internal_links(
def start_playwright() -> Tuple[Playwright, BrowserContext]:
playwright = sync_playwright().start()
browser = playwright.chromium.launch(headless=True)
context = browser.new_context()
@@ -288,6 +294,7 @@ class WebConnector(LoadConnector):
and converts them into documents"""
visited_links: set[str] = set()
to_visit: list[str] = self.to_visit_list
content_hashes = set()
if not to_visit:
raise ValueError("No URLs to visit")
@@ -302,40 +309,41 @@ class WebConnector(LoadConnector):
playwright, context = start_playwright()
restart_playwright = False
while to_visit:
current_url = to_visit.pop()
if current_url in visited_links:
initial_url = to_visit.pop()
if initial_url in visited_links:
continue
visited_links.add(current_url)
visited_links.add(initial_url)
try:
protected_url_check(current_url)
protected_url_check(initial_url)
except Exception as e:
last_error = f"Invalid URL {current_url} due to {e}"
last_error = f"Invalid URL {initial_url} due to {e}"
logger.warning(last_error)
continue
logger.info(f"Visiting {current_url}")
index = len(visited_links)
logger.info(f"{index}: Visiting {initial_url}")
try:
check_internet_connection(current_url)
check_internet_connection(initial_url)
if restart_playwright:
playwright, context = start_playwright()
restart_playwright = False
if current_url.split(".")[-1] == "pdf":
if initial_url.split(".")[-1] == "pdf":
# PDF files are not checked for links
response = requests.get(current_url)
page_text, metadata = read_pdf_file(
response = requests.get(initial_url)
page_text, metadata, images = read_pdf_file(
file=io.BytesIO(response.content)
)
last_modified = response.headers.get("Last-Modified")
doc_batch.append(
Document(
id=current_url,
sections=[Section(link=current_url, text=page_text)],
id=initial_url,
sections=[Section(link=initial_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=current_url.split("/")[-1],
semantic_identifier=initial_url.split("/")[-1],
metadata=metadata,
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
@@ -347,21 +355,29 @@ class WebConnector(LoadConnector):
continue
page = context.new_page()
page_response = page.goto(current_url)
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
page_response = page.goto(
initial_url,
timeout=30000, # 30 seconds
)
last_modified = (
page_response.header_value("Last-Modified")
if page_response
else None
)
final_page = page.url
if final_page != current_url:
logger.info(f"Redirected to {final_page}")
protected_url_check(final_page)
current_url = final_page
if current_url in visited_links:
logger.info("Redirected page already indexed")
final_url = page.url
if final_url != initial_url:
protected_url_check(final_url)
initial_url = final_url
if initial_url in visited_links:
logger.info(
f"{index}: {initial_url} redirected to {final_url} - already indexed"
)
continue
visited_links.add(current_url)
logger.info(f"{index}: {initial_url} redirected to {final_url}")
visited_links.add(initial_url)
if self.scroll_before_scraping:
scroll_attempts = 0
@@ -379,26 +395,58 @@ class WebConnector(LoadConnector):
soup = BeautifulSoup(content, "html.parser")
if self.recursive:
internal_links = get_internal_links(base_url, current_url, soup)
internal_links = get_internal_links(base_url, initial_url, soup)
for link in internal_links:
if link not in visited_links:
to_visit.append(link)
if page_response and str(page_response.status)[0] in ("4", "5"):
last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response"
last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response"
logger.info(last_error)
continue
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
"""For websites containing iframes that need to be scraped,
the code below can extract text from within these iframes.
"""
logger.debug(
f"{index}: Length of cleaned text {len(parsed_html.cleaned_text)}"
)
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
iframe_count = page.frame_locator("iframe").locator("html").count()
if iframe_count > 0:
iframe_texts = (
page.frame_locator("iframe")
.locator("html")
.all_inner_texts()
)
document_text = "\n".join(iframe_texts)
""" 700 is the threshold value for the length of the text extracted
from the iframe based on the issue faced """
if len(parsed_html.cleaned_text) < IFRAME_TEXT_LENGTH_THRESHOLD:
parsed_html.cleaned_text = document_text
else:
parsed_html.cleaned_text += "\n" + document_text
# Sometimes pages with #! will serve duplicate content
# There are also just other ways this can happen
hashed_text = hash((parsed_html.title, parsed_html.cleaned_text))
if hashed_text in content_hashes:
logger.info(
f"{index}: Skipping duplicate title + content for {initial_url}"
)
continue
content_hashes.add(hashed_text)
doc_batch.append(
Document(
id=current_url,
id=initial_url,
sections=[
Section(link=current_url, text=parsed_html.cleaned_text)
Section(link=initial_url, text=parsed_html.cleaned_text)
],
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or current_url,
semantic_identifier=parsed_html.title or initial_url,
metadata={},
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
@@ -410,7 +458,7 @@ class WebConnector(LoadConnector):
page.close()
except Exception as e:
last_error = f"Failed to fetch '{current_url}': {e}"
last_error = f"Failed to fetch '{initial_url}': {e}"
logger.exception(last_error)
playwright.stop()
restart_playwright = True
@@ -481,7 +529,9 @@ class WebConnector(LoadConnector):
)
else:
# Could be a 5xx or another error, treat as unexpected
raise UnexpectedError(f"Unexpected error validating '{test_url}': {e}")
raise UnexpectedValidationError(
f"Unexpected error validating '{test_url}': {e}"
)
if __name__ == "__main__":

View File

@@ -76,6 +76,10 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
provider_type=search_settings.provider_type,
index_name=search_settings.index_name,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
# Whether switching to this model requires re-indexing
background_reindex_enabled=search_settings.background_reindex_enabled,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,

View File

@@ -1,12 +1,17 @@
import base64
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
import numpy
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from onyx.chat.models import SectionRelevancePiece
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from onyx.context.search.enums import LLMEvaluationType
@@ -18,11 +23,15 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.db.engine import get_session_with_current_tenant
from onyx.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.search_nlp_models import RerankingModel
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
@@ -30,6 +39,124 @@ from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
def update_image_sections_with_query(
sections: list[InferenceSection],
query: str,
llm: LLM,
) -> None:
"""
For each chunk in each section that has an image URL, call an LLM to produce
a new 'content' string that directly addresses the user's query about that image.
This implementation uses parallel processing for efficiency.
"""
logger = setup_logger()
logger.debug(f"Starting image section update with query: {query}")
chunks_with_images = []
for section in sections:
for chunk in section.chunks:
if chunk.image_file_name:
chunks_with_images.append(chunk)
if not chunks_with_images:
logger.debug("No images to process in the sections")
return # No images to process
logger.info(f"Found {len(chunks_with_images)} chunks with images to process")
def process_image_chunk(chunk: InferenceChunk) -> tuple[str, str]:
try:
logger.debug(
f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_name}"
)
with get_session_with_current_tenant() as db_session:
file_record = get_default_file_store(db_session).read_file(
cast(str, chunk.image_file_name), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_name}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(
f"Successfully loaded image data for {chunk.image_file_name}"
)
messages: list[BaseMessage] = [
SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),
HumanMessage(
content=[
{
"type": "text",
"text": (
f"The user's question is: '{query}'. "
"Please analyze the following image in that context:\n"
),
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}",
},
},
]
),
]
raw_response = llm.invoke(messages)
answer_text = message_to_string(raw_response).strip()
return (
chunk.unique_id,
answer_text if answer_text else "No relevant info found.",
)
except Exception:
logger.exception(
f"Error updating image section with query source image url: {chunk.image_file_name}"
)
return chunk.unique_id, "Error analyzing image."
image_processing_tasks = [
FunctionCall(process_image_chunk, (chunk,)) for chunk in chunks_with_images
]
logger.info(
f"Starting parallel processing of {len(image_processing_tasks)} image tasks"
)
image_processing_results = run_functions_in_parallel(image_processing_tasks)
logger.info(
f"Completed parallel processing with {len(image_processing_results)} results"
)
# Create a mapping of chunk IDs to their processed content
chunk_id_to_content = {}
success_count = 0
for task_id, result in image_processing_results.items():
if result:
chunk_id, content = result
chunk_id_to_content[chunk_id] = content
success_count += 1
else:
logger.error(f"Task {task_id} failed to return a valid result")
logger.info(
f"Successfully processed {success_count}/{len(image_processing_results)} images"
)
# Update the chunks with the processed content
updated_count = 0
for section in sections:
for chunk in section.chunks:
if chunk.unique_id in chunk_id_to_content:
chunk.content = chunk_id_to_content[chunk.unique_id]
updated_count += 1
logger.info(
f"Updated content for {updated_count} chunks with image analysis results"
)
logger = setup_logger()
@@ -286,6 +413,10 @@ def search_postprocessing(
# NOTE: if we don't rerank, we can return the chunks immediately
# since we know this is the final order.
# This way the user experience isn't delayed by the LLM step
if get_search_time_image_analysis_enabled():
update_image_sections_with_query(
retrieved_sections, search_query.query, llm
)
_log_top_section_links(search_query.search_type.value, retrieved_sections)
yield retrieved_sections
sections_yielded = True
@@ -323,6 +454,13 @@ def search_postprocessing(
)
else:
_log_top_section_links(search_query.search_type.value, reranked_sections)
# Add the image processing step here
if get_search_time_image_analysis_enabled():
update_image_sections_with_query(
reranked_sections, search_query.query, llm
)
yield reranked_sections
llm_selected_section_ids = (

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from typing import Tuple
from uuid import UUID
from fastapi import HTTPException
@@ -11,6 +12,7 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import Row
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
@@ -375,24 +377,33 @@ def delete_chat_session(
db_session.commit()
def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
def get_chat_sessions_older_than(
days_old: int, db_session: Session
) -> list[tuple[UUID | None, UUID]]:
"""
Retrieves chat sessions older than a specified number of days.
Args:
days_old: The number of days to consider as "old".
db_session: The database session.
Returns:
A list of tuples, where each tuple contains the user_id (can be None) and the chat_session_id of an old chat session.
"""
cutoff_time = datetime.utcnow() - timedelta(days=days_old)
old_sessions = db_session.execute(
old_sessions: Sequence[Row[Tuple[UUID | None, UUID]]] = db_session.execute(
select(ChatSession.user_id, ChatSession.id).where(
ChatSession.time_created < cutoff_time
)
).fetchall()
for user_id, session_id in old_sessions:
try:
delete_chat_session(
user_id, session_id, db_session, include_deleted=True, hard_delete=True
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
# convert old_sessions to a conventional list of tuples
returned_sessions: list[tuple[UUID | None, UUID]] = [
(user_id, session_id) for user_id, session_id in old_sessions
]
return returned_sessions
def get_chat_message(

View File

@@ -360,18 +360,13 @@ def backend_update_credential_json(
db_session.commit()
def delete_credential(
def _delete_credential_internal(
credential: Credential,
credential_id: int,
user: User | None,
db_session: Session,
force: bool = False,
) -> None:
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
"""Internal utility function to handle the actual deletion of a credential"""
associated_connectors = (
db_session.query(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.credential_id == credential_id)
@@ -416,6 +411,35 @@ def delete_credential(
db_session.commit()
def delete_credential_for_user(
credential_id: int,
user: User,
db_session: Session,
force: bool = False,
) -> None:
"""Delete a credential that belongs to a specific user"""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
_delete_credential_internal(credential, credential_id, db_session, force)
def delete_credential(
credential_id: int,
db_session: Session,
force: bool = False,
) -> None:
"""Delete a credential regardless of ownership (admin function)"""
credential = fetch_credential_by_id(credential_id, db_session)
if credential is None:
raise ValueError(f"Credential by provided id {credential_id} does not exist")
_delete_credential_internal(credential, credential_id, db_session, force)
def create_initial_public_credential(db_session: Session) -> None:
error_msg = (
"DB is not in a valid initial state."

View File

@@ -63,6 +63,9 @@ class IndexModelStatus(str, PyEnum):
PRESENT = "PRESENT"
FUTURE = "FUTURE"
def is_current(self) -> bool:
return self == IndexModelStatus.PRESENT
class ChatSessionSharedStatus(str, PyEnum):
PUBLIC = "public"
@@ -83,3 +86,11 @@ class AccessType(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
SYNC = "sync"
class EmbeddingPrecision(str, PyEnum):
# matches vespa tensor type
# only support float / bfloat16 for now, since there's not a
# good reason to specify anything else
BFLOAT16 = "bfloat16"
FLOAT = "float"

View File

@@ -46,7 +46,13 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
from onyx.db.enums import (
AccessType,
EmbeddingPrecision,
IndexingMode,
SyncType,
SyncStatus,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import TokenRateLimitScope
@@ -716,6 +722,23 @@ class SearchSettings(Base):
ForeignKey("embedding_provider.provider_type"), nullable=True
)
# Whether switching to this model should re-index all connectors in the background
# if no re-index is needed, will be ignored. Only used during the switch-over process.
background_reindex_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# allows for quantization -> less memory usage for a small performance hit
embedding_precision: Mapped[EmbeddingPrecision] = mapped_column(
Enum(EmbeddingPrecision, native_enum=False)
)
# can be used to reduce dimensionality of vectors and save memory with
# a small performance hit. More details in the `Reducing embedding dimensions`
# section here:
# https://platform.openai.com/docs/guides/embeddings#embedding-models
# If not specified, will just use the model_dim without any reduction.
# NOTE: this is only currently available for OpenAI models
reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
@@ -797,6 +820,12 @@ class SearchSettings(Base):
self.multipass_indexing, self.model_name, self.provider_type
)
@property
def final_embedding_dim(self) -> int:
if self.reduced_dimension:
return self.reduced_dimension
return self.model_dim
@staticmethod
def can_use_large_chunks(
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
@@ -1761,6 +1790,7 @@ class ChannelConfig(TypedDict):
channel_name: str | None # None for default channel config
respond_tag_only: NotRequired[bool] # defaults to False
respond_to_bots: NotRequired[bool] # defaults to False
is_ephemeral: NotRequired[bool] # defaults to False
respond_member_group_list: NotRequired[list[str]]
answer_filters: NotRequired[list[AllowedAnswerFilters]]
# If None then no follow up

View File

@@ -209,13 +209,21 @@ def create_update_persona(
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
is_default_persona: bool | None = create_persona_request.is_default_persona
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non public")
if user and user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
if user:
# Curators can edit default personas, but not make them
if (
user.role == UserRole.CURATOR
or user.role == UserRole.GLOBAL_CURATOR
):
is_default_persona = None
elif user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
persona = upsert_persona(
persona_id=persona_id,
@@ -241,7 +249,7 @@ def create_update_persona(
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=create_persona_request.is_default_persona,
is_default_persona=is_default_persona,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -428,7 +436,7 @@ def upsert_persona(
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
is_default_persona: bool | None = None,
label_ids: list[int] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
@@ -523,7 +531,11 @@ def upsert_persona(
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
existing_persona.is_default_persona = is_default_persona
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None
else existing_persona.is_default_persona
)
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -575,7 +587,9 @@ def upsert_persona(
display_priority=display_priority,
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
is_default_persona=is_default_persona
if is_default_persona is not None
else False,
labels=labels or [],
)
db_session.add(new_persona)

View File

@@ -148,3 +148,28 @@ def upsert_pgfilestore(
db_session.commit()
return pgfilestore
def save_bytes_to_pgfilestore(
db_session: Session,
raw_bytes: bytes,
media_type: str,
identifier: str,
display_name: str,
file_origin: FileOrigin = FileOrigin.OTHER,
) -> PGFileStore:
"""
Saves raw bytes to PGFileStore and returns the resulting record.
"""
file_name = f"{file_origin.name.lower()}_{identifier}"
lobj_oid = create_populate_lobj(BytesIO(raw_bytes), db_session)
pgfilestore = upsert_pgfilestore(
file_name=file_name,
display_name=display_name,
file_origin=file_origin,
file_type=media_type,
lobj_oid=lobj_oid,
db_session=db_session,
commit=True,
)
return pgfilestore

View File

@@ -14,6 +14,7 @@ from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from onyx.context.search.models import SavedSearchSettings
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_embedding_provider
from onyx.db.models import CloudEmbeddingProvider
from onyx.db.models import IndexAttempt
@@ -59,12 +60,15 @@ def create_search_settings(
index_name=search_settings.index_name,
provider_type=search_settings.provider_type,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
background_reindex_enabled=search_settings.background_reindex_enabled,
)
db_session.add(embedding_model)
@@ -305,6 +309,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
embedding_precision=(EmbeddingPrecision.FLOAT),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
@@ -322,6 +327,7 @@ def get_new_default_embedding_model() -> IndexingSetting:
return IndexingSetting(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
embedding_precision=(EmbeddingPrecision.FLOAT),
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,

View File

@@ -0,0 +1,53 @@
import random
from datetime import datetime
from datetime import timedelta
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatSession
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
"""Utility function to seed chat history for testing.
num_sessions: the number of sessions to seed
num_messages: the number of messages to seed per sessions
days: the number of days looking backwards from the current time over which to randomize
the times.
"""
with get_session_with_current_tenant() as db_session:
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", None, None)
# randomize all session times
rows = db_session.query(ChatSession).all()
for row in rows:
row.time_created = datetime.utcnow() - timedelta(
days=random.randint(0, days)
)
row.time_updated = row.time_created + timedelta(
minutes=random.randint(0, 10)
)
root_message = get_or_create_root_message(row.id, db_session)
for x in range(0, num_messages):
chat_message = create_new_chat_message(
row.id,
root_message,
f"pytest_message_{x}",
None,
0,
MessageType.USER,
db_session,
)
chat_message.time_sent = row.time_created + timedelta(
minutes=random.randint(0, 10)
)
db_session.commit()
db_session.commit()

View File

@@ -8,10 +8,12 @@ from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_default_document_index
from onyx.key_value_store.factory import get_kv_store
from onyx.utils.logger import setup_logger
@@ -19,7 +21,49 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def check_index_swap(db_session: Session) -> SearchSettings | None:
def _perform_index_swap(
db_session: Session,
current_search_settings: SearchSettings,
secondary_search_settings: SearchSettings,
all_cc_pairs: list[ConnectorCredentialPair],
) -> None:
"""Swap the indices and expire the old one."""
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if len(all_cc_pairs) > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
# remove the old index from the vector db
document_index = get_default_document_index(secondary_search_settings, None)
document_index.ensure_indices_exist(
primary_embedding_dim=secondary_search_settings.final_embedding_dim,
primary_embedding_precision=secondary_search_settings.embedding_precision,
# just finished swap, no more secondary index
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one.
@@ -27,52 +71,45 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
old_search_settings = None
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
search_settings = get_secondary_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
if not search_settings:
if not secondary_search_settings:
return None
# If the secondary search settings are not configured to reindex in the background,
# we can just swap over instantly
if not secondary_search_settings.background_reindex_enabled:
current_search_settings = get_current_search_settings(db_session)
_perform_index_swap(
db_session=db_session,
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
)
return current_search_settings
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
search_settings_id=secondary_search_settings.id, db_session=db_session
)
# Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this
# function is correct. The unique_cc_indexings are specifically for the existing cc-pairs
old_search_settings = None
if unique_cc_indexings > cc_pair_count:
logger.error("More unique indexings than cc pairs, should not occur")
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
# Swap indices
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
_perform_index_swap(
db_session=db_session,
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
)
update_search_settings_status(
search_settings=search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if cc_pair_count > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
old_search_settings = current_search_settings
old_search_settings = current_search_settings
return old_search_settings

View File

@@ -6,6 +6,7 @@ from typing import Any
from onyx.access.models import DocumentAccess
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.model_server_models import Embedding
@@ -145,17 +146,21 @@ class Verifiable(abc.ABC):
@abc.abstractmethod
def ensure_indices_exist(
self,
index_embedding_dim: int,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
) -> None:
"""
Verify that the document index exists and is consistent with the expectations in the code.
Parameters:
- index_embedding_dim: Vector dimensionality for the vector similarity part of the search
- primary_embedding_dim: Vector dimensionality for the vector similarity part of the search
- primary_embedding_precision: Precision of the vector similarity part of the search
- secondary_index_embedding_dim: Vector dimensionality of the secondary index being built
behind the scenes. The secondary index should only be built when switching
embedding models therefore this dim should be different from the primary index.
- secondary_index_embedding_precision: Precision of the vector similarity part of the secondary index
"""
raise NotImplementedError
@@ -164,6 +169,7 @@ class Verifiable(abc.ABC):
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
"""
Register multitenant indices with the document index.

View File

@@ -37,7 +37,7 @@ schema DANSWER_CHUNK_NAME {
summary: dynamic
}
# Title embedding (x1)
field title_embedding type tensor<float>(x[VARIABLE_DIM]) {
field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) {
indexing: attribute | index
attribute {
distance-metric: angular
@@ -45,7 +45,7 @@ schema DANSWER_CHUNK_NAME {
}
# Content embeddings (chunk + optional mini chunks embeddings)
# "t" and "x" are arbitrary names, not special keywords
field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) {
field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) {
indexing: attribute | index
attribute {
distance-metric: angular
@@ -55,6 +55,9 @@ schema DANSWER_CHUNK_NAME {
field blurb type string {
indexing: summary | attribute
}
field image_file_name type string {
indexing: summary | attribute
}
# https://docs.vespa.ai/en/attributes.html potential enum store for speed, but probably not worth it
field source_type type string {
indexing: summary | attribute

View File

@@ -5,4 +5,7 @@
<allow
until="DATE_REPLACEMENT"
comment="We need to be able to update the schema for updates to the Onyx schema">indexing-change</allow>
<allow
until='DATE_REPLACEMENT'
comment="Prevents old alt indices from interfering with changes">field-type-change</allow>
</validation-overrides>

View File

@@ -31,6 +31,7 @@ from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
from onyx.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE
from onyx.document_index.vespa_constants import MAX_OR_CONDITIONS
@@ -130,6 +131,7 @@ def _vespa_hit_to_inference_chunk(
section_continuation=fields[SECTION_CONTINUATION],
document_id=fields[DOCUMENT_ID],
source_type=fields[SOURCE_TYPE],
image_file_name=fields.get(IMAGE_FILE_NAME),
title=fields.get(TITLE),
semantic_identifier=fields[SEMANTIC_IDENTIFIER],
boost=fields.get(BOOST, 1),
@@ -211,6 +213,7 @@ def _get_chunks_via_visit_api(
# Check if the response contains any documents
response_data = response.json()
if "documents" in response_data:
for document in response_data["documents"]:
if filters.access_control_list:
@@ -310,6 +313,11 @@ def query_vespa(
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
f"Exception: {str(e)}"
+ (
f"\nResponse: {e.response.text}"
if isinstance(e, httpx.HTTPStatusError)
else ""
)
)
raise httpx.HTTPError(error_base) from e

View File

@@ -26,6 +26,7 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
@@ -63,6 +64,7 @@ from onyx.document_index.vespa_constants import DATE_REPLACEMENT
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
@@ -112,6 +114,21 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
return "\n".join(doc_lines)
def _replace_template_values_in_schema(
schema_template: str,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
) -> str:
return (
schema_template.replace(
EMBEDDING_PRECISION_REPLACEMENT_PAT, embedding_precision.value
)
.replace(DANSWER_CHUNK_REPLACEMENT_PAT, index_name)
.replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
)
def add_ngrams_to_schema(schema_content: str) -> str:
# Add the match blocks containing gram and gram-size to title and content fields
schema_content = re.sub(
@@ -163,8 +180,10 @@ class VespaIndex(DocumentIndex):
def ensure_indices_exist(
self,
index_embedding_dim: int,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
) -> None:
if MULTI_TENANT:
logger.info(
@@ -221,18 +240,29 @@ class VespaIndex(DocumentIndex):
schema_template = schema_f.read()
schema_template = schema_template.replace(TENANT_ID_PAT, "")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = _replace_template_values_in_schema(
schema_template,
self.index_name,
primary_embedding_dim,
primary_embedding_precision,
)
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
upcoming_schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
if secondary_index_embedding_dim is None:
raise ValueError("Secondary index embedding dimension is required")
if secondary_index_embedding_precision is None:
raise ValueError("Secondary index embedding precision is required")
upcoming_schema = _replace_template_values_in_schema(
schema_template,
self.secondary_index_name,
secondary_index_embedding_dim,
secondary_index_embedding_precision,
)
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
@@ -251,6 +281,7 @@ class VespaIndex(DocumentIndex):
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
if not MULTI_TENANT:
raise ValueError("Multi-tenant is not enabled")
@@ -309,13 +340,14 @@ class VespaIndex(DocumentIndex):
for i, index_name in enumerate(indices):
embedding_dim = embedding_dims[i]
embedding_precision = embedding_precisions[i]
logger.info(
f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
)
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
schema = _replace_template_values_in_schema(
schema_template, index_name, embedding_dim, embedding_precision
)
schema = schema.replace(
TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else ""
)

View File

@@ -32,6 +32,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import EMBEDDINGS
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_LIST
@@ -198,13 +199,13 @@ def _index_vespa_chunk(
# which only calls VespaIndex.update
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
IMAGE_FILE_NAME: chunk.image_file_name,
BOOST: chunk.boost,
}
if multitenant:
if chunk.tenant_id:
vespa_document_fields[TENANT_ID] = chunk.tenant_id
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
logger.debug(f'Indexing to URL "{vespa_url}"')
res = http_client.post(

View File

@@ -6,6 +6,7 @@ from onyx.configs.app_configs import VESPA_TENANT_PORT
from onyx.configs.constants import SOURCE_TYPE
VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION"
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
@@ -76,6 +77,7 @@ PRIMARY_OWNERS = "primary_owners"
SECONDARY_OWNERS = "secondary_owners"
RECENCY_BIAS = "recency_bias"
HIDDEN = "hidden"
IMAGE_FILE_NAME = "image_file_name"
# Specific to Vespa, needed for highlighting matching keywords / section
CONTENT_SUMMARY = "content_summary"
@@ -93,6 +95,7 @@ YQL_BASE = (
f"{SEMANTIC_IDENTIFIER}, "
f"{TITLE}, "
f"{SECTION_CONTINUATION}, "
f"{IMAGE_FILE_NAME}, "
f"{BOOST}, "
f"{HIDDEN}, "
f"{DOC_UPDATED_AT}, "

View File

@@ -9,15 +9,17 @@ from email.parser import Parser as EmailParser
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Dict
from typing import IO
from typing import List
from typing import Tuple
import chardet
import docx # type: ignore
import openpyxl # type: ignore
import pptx # type: ignore
from docx import Document
from docx import Document as DocxDocument
from fastapi import UploadFile
from PIL import Image
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
@@ -31,10 +33,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
TEXT_SECTION_SEPARATOR = "\n\n"
PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
@@ -49,7 +49,6 @@ PLAIN_TEXT_FILE_EXTENSIONS = [
".yaml",
]
VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
".pdf",
".docx",
@@ -58,6 +57,16 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
".eml",
".epub",
".html",
".png",
".jpg",
".jpeg",
".webp",
]
IMAGE_MEDIA_TYPES = [
"image/png",
"image/jpeg",
"image/webp",
]
@@ -67,11 +76,13 @@ def is_text_file_extension(file_name: str) -> bool:
def get_file_ext(file_path_or_name: str | Path) -> str:
_, extension = os.path.splitext(file_path_or_name)
# standardize all extensions to be lowercase so that checks against
# VALID_FILE_EXTENSIONS and similar will work as intended
return extension.lower()
def is_valid_media_type(media_type: str) -> bool:
return media_type in IMAGE_MEDIA_TYPES
def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
@@ -79,17 +90,18 @@ def is_valid_file_ext(ext: str) -> bool:
def is_text_file(file: IO[bytes]) -> bool:
"""
checks if the first 1024 bytes only contain printable or whitespace characters
if it does, then we say its a plaintext file
if it does, then we say it's a plaintext file
"""
raw_data = file.read(1024)
file.seek(0)
text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F})
return all(c in text_chars for c in raw_data)
def detect_encoding(file: IO[bytes]) -> str:
raw_data = file.read(50000)
encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
file.seek(0)
encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
return encoding
@@ -99,14 +111,14 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
)
# To include additional metadata in the search index, add a .onyx_metadata.json file
# to the zip file. This file should contain a list of objects with the following format:
# [{ "filename": "file1.txt", "link": "https://example.com/file1.txt" }]
def load_files_from_zip(
zip_file_io: IO,
ignore_macos_resource_fork_files: bool = True,
ignore_dirs: bool = True,
) -> Iterator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]]]:
"""
If there's a .onyx_metadata.json in the zip, attach those metadata to each subfile.
"""
with zipfile.ZipFile(zip_file_io, "r") as zip_file:
zip_metadata = {}
try:
@@ -118,24 +130,31 @@ def load_files_from_zip(
# convert list of dicts to dict of dicts
zip_metadata = {d["filename"]: d for d in zip_metadata}
except json.JSONDecodeError:
logger.warn(f"Unable to load {DANSWER_METADATA_FILENAME}")
logger.warning(f"Unable to load {DANSWER_METADATA_FILENAME}")
except KeyError:
logger.info(f"No {DANSWER_METADATA_FILENAME} file")
for file_info in zip_file.infolist():
with zip_file.open(file_info.filename, "r") as file:
if ignore_dirs and file_info.is_dir():
continue
if ignore_dirs and file_info.is_dir():
continue
if (
ignore_macos_resource_fork_files
and is_macos_resource_fork_file(file_info.filename)
) or file_info.filename == DANSWER_METADATA_FILENAME:
continue
yield file_info, file, zip_metadata.get(file_info.filename, {})
if (
ignore_macos_resource_fork_files
and is_macos_resource_fork_file(file_info.filename)
) or file_info.filename == DANSWER_METADATA_FILENAME:
continue
with zip_file.open(file_info.filename, "r") as subfile:
yield file_info, subfile, zip_metadata.get(file_info.filename, {})
def _extract_onyx_metadata(line: str) -> dict | None:
"""
Example: first line has:
<!-- DANSWER_METADATA={"title": "..."} -->
or
#DANSWER_METADATA={"title":"..."}
"""
html_comment_pattern = r"<!--\s*DANSWER_METADATA=\{(.*?)\}\s*-->"
hashtag_pattern = r"#DANSWER_METADATA=\{(.*?)\}"
@@ -161,9 +180,13 @@ def read_text_file(
errors: str = "replace",
ignore_onyx_metadata: bool = True,
) -> tuple[str, dict]:
"""
For plain text files. Optionally extracts Onyx metadata from the first line.
"""
metadata = {}
file_content_raw = ""
for ind, line in enumerate(file):
# decode
try:
line = line.decode(encoding) if isinstance(line, bytes) else line
except UnicodeDecodeError:
@@ -173,131 +196,132 @@ def read_text_file(
else line
)
if ind == 0:
metadata_or_none = (
None if ignore_onyx_metadata else _extract_onyx_metadata(line)
)
if metadata_or_none is not None:
metadata = metadata_or_none
else:
file_content_raw += line
else:
file_content_raw += line
# optionally parse metadata in the first line
if ind == 0 and not ignore_onyx_metadata:
potential_meta = _extract_onyx_metadata(line)
if potential_meta is not None:
metadata = potential_meta
continue
file_content_raw += line
return file_content_raw, metadata
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
"""Extract text from a PDF file."""
# Return only the extracted text from read_pdf_file
text, _ = read_pdf_file(file, pdf_pass)
"""
Extract text from a PDF. For embedded images, a more complex approach is needed.
This is a minimal approach returning text only.
"""
text, _, _ = read_pdf_file(file, pdf_pass)
return text
def read_pdf_file(
file: IO[Any],
pdf_pass: str | None = None,
) -> tuple[str, dict]:
metadata: Dict[str, Any] = {}
file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False
) -> tuple[str, dict, list[tuple[bytes, str]]]:
"""
Returns the text, basic PDF metadata, and optionally extracted images.
"""
metadata: dict[str, Any] = {}
extracted_images: list[tuple[bytes, str]] = []
try:
pdf_reader = PdfReader(file)
# If marked as encrypted and a password is provided, try to decrypt
if pdf_reader.is_encrypted and pdf_pass is not None:
decrypt_success = False
if pdf_pass is not None:
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error("Unable to decrypt pdf")
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error("Unable to decrypt pdf")
if not decrypt_success:
# By user request, keep files that are unreadable just so they
# can be discoverable by title.
return "", metadata
return "", metadata, []
elif pdf_reader.is_encrypted:
logger.warning("No Password available to decrypt pdf, returning empty")
return "", metadata
logger.warning("No Password for an encrypted PDF, returning empty text.")
return "", metadata, []
# Extract metadata from the PDF, removing leading '/' from keys if present
# This standardizes the metadata keys for consistency
metadata = {}
# Basic PDF metadata
if pdf_reader.metadata is not None:
for key, value in pdf_reader.metadata.items():
clean_key = key.lstrip("/")
if isinstance(value, str) and value.strip():
metadata[clean_key] = value
elif isinstance(value, list) and all(
isinstance(item, str) for item in value
):
metadata[clean_key] = ", ".join(value)
return (
TEXT_SECTION_SEPARATOR.join(
page.extract_text() for page in pdf_reader.pages
),
metadata,
text = TEXT_SECTION_SEPARATOR.join(
page.extract_text() for page in pdf_reader.pages
)
if extract_images:
for page_num, page in enumerate(pdf_reader.pages):
for image_file_object in page.images:
image = Image.open(io.BytesIO(image_file_object.data))
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format=image.format)
img_bytes = img_byte_arr.getvalue()
image_name = (
f"page_{page_num + 1}_image_{image_file_object.name}."
f"{image.format.lower() if image.format else 'png'}"
)
extracted_images.append((img_bytes, image_name))
return text, metadata, extracted_images
except PdfStreamError:
logger.exception("PDF file is not a valid PDF")
logger.exception("Invalid PDF file")
except Exception:
logger.exception("Failed to read PDF")
# File is still discoverable by title
# but the contents are not included as they cannot be parsed
return "", metadata
return "", metadata, []
def docx_to_text(file: IO[Any]) -> str:
def is_simple_table(table: docx.table.Table) -> bool:
for row in table.rows:
# No omitted cells
if row.grid_cols_before > 0 or row.grid_cols_after > 0:
return False
# No nested tables
if any(cell.tables for cell in row.cells):
return False
return True
def extract_cell_text(cell: docx.table._Cell) -> str:
cell_paragraphs = [para.text.strip() for para in cell.paragraphs]
return " ".join(p for p in cell_paragraphs if p) or "N/A"
def docx_to_text_and_images(
file: IO[Any],
) -> Tuple[str, List[Tuple[bytes, str]]]:
"""
Extract text from a docx. If embed_images=True, also extract inline images.
Return (text_content, list_of_images).
"""
paragraphs = []
embedded_images: List[Tuple[bytes, str]] = []
doc = docx.Document(file)
for item in doc.iter_inner_content():
if isinstance(item, docx.text.paragraph.Paragraph):
paragraphs.append(item.text)
elif isinstance(item, docx.table.Table):
if not item.rows or not is_simple_table(item):
continue
# Grab text from paragraphs
for paragraph in doc.paragraphs:
paragraphs.append(paragraph.text)
# Every row is a new line, joined with a single newline
table_content = "\n".join(
[
",\t".join(extract_cell_text(cell) for cell in row.cells)
for row in item.rows
]
)
paragraphs.append(table_content)
# Reset position so we can re-load the doc (python-docx has read the stream)
# Note: if python-docx has fully consumed the stream, you may need to open it again from memory.
# For large docs, a more robust approach is needed.
# This is a simplified example.
# Docx already has good spacing between paragraphs
return "\n".join(paragraphs)
for rel_id, rel in doc.part.rels.items():
if "image" in rel.reltype:
# image is typically in rel.target_part.blob
image_bytes = rel.target_part.blob
image_name = rel.target_part.partname
# store
embedded_images.append((image_bytes, os.path.basename(str(image_name))))
text_content = "\n".join(paragraphs)
return text_content, embedded_images
def pptx_to_text(file: IO[Any]) -> str:
presentation = pptx.Presentation(file)
text_content = []
for slide_number, slide in enumerate(presentation.slides, start=1):
extracted_text = f"\nSlide {slide_number}:\n"
slide_text = f"\nSlide {slide_number}:\n"
for shape in slide.shapes:
if hasattr(shape, "text"):
extracted_text += shape.text + "\n"
text_content.append(extracted_text)
slide_text += shape.text + "\n"
text_content.append(slide_text)
return TEXT_SECTION_SEPARATOR.join(text_content)
@@ -305,18 +329,21 @@ def xlsx_to_text(file: IO[Any]) -> str:
workbook = openpyxl.load_workbook(file, read_only=True)
text_content = []
for sheet in workbook.worksheets:
sheet_string = "\n".join(
",".join(map(str, row))
for row in sheet.iter_rows(min_row=1, values_only=True)
)
text_content.append(sheet_string)
rows = []
for row in sheet.iter_rows(min_row=1, values_only=True):
row_str = ",".join(str(cell) if cell is not None else "" for cell in row)
rows.append(row_str)
sheet_str = "\n".join(rows)
text_content.append(sheet_str)
return TEXT_SECTION_SEPARATOR.join(text_content)
def eml_to_text(file: IO[Any]) -> str:
text_file = io.TextIOWrapper(file, encoding=detect_encoding(file))
encoding = detect_encoding(file)
text_file = io.TextIOWrapper(file, encoding=encoding)
parser = EmailParser()
message = parser.parse(text_file)
text_content = []
for part in message.walk():
if part.get_content_type().startswith("text/plain"):
@@ -342,8 +369,8 @@ def epub_to_text(file: IO[Any]) -> str:
def file_io_to_text(file: IO[Any]) -> str:
encoding = detect_encoding(file)
file_content_raw, _ = read_text_file(file, encoding=encoding)
return file_content_raw
file_content, _ = read_text_file(file, encoding=encoding)
return file_content
def extract_file_text(
@@ -352,9 +379,13 @@ def extract_file_text(
break_on_unprocessable: bool = True,
extension: str | None = None,
) -> str:
"""
Legacy function that returns *only text*, ignoring embedded images.
For backward-compatibility in code that only wants text.
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
".docx": docx_to_text,
".docx": lambda f: docx_to_text_and_images(f)[0], # no images
".pptx": pptx_to_text,
".xlsx": xlsx_to_text,
".eml": eml_to_text,
@@ -368,24 +399,23 @@ def extract_file_text(
return unstructured_to_text(file, file_name)
except Exception as unstructured_error:
logger.error(
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
f"Failed to process with Unstructured: {str(unstructured_error)}. "
"Falling back to normal processing."
)
# Fall through to normal processing
final_extension: str
if file_name or extension:
if extension is not None:
final_extension = extension
elif file_name is not None:
final_extension = get_file_ext(file_name)
if extension is None:
extension = get_file_ext(file_name)
if is_valid_file_ext(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)
if is_valid_file_ext(extension):
func = extension_to_function.get(extension, file_io_to_text)
file.seek(0)
return func(file)
# Either the file somehow has no name or the extension is not one that we recognize
# If unknown extension, maybe it's a text file
file.seek(0)
if is_text_file(file):
return file_io_to_text(file)
raise ValueError("Unknown file extension and unknown text encoding")
raise ValueError("Unknown file extension or not recognized as text data")
except Exception as e:
if break_on_unprocessable:
@@ -396,20 +426,93 @@ def extract_file_text(
return ""
def extract_text_and_images(
file: IO[Any],
file_name: str,
pdf_pass: str | None = None,
) -> Tuple[str, List[Tuple[bytes, str]]]:
"""
Primary new function for the updated connector.
Returns (text_content, [(embedded_img_bytes, embedded_img_name), ...]).
"""
try:
# Attempt unstructured if env var is set
if get_unstructured_api_key():
# If the user doesn't want embedded images, unstructured is fine
file.seek(0)
text_content = unstructured_to_text(file, file_name)
return (text_content, [])
extension = get_file_ext(file_name)
# docx example for embedded images
if extension == ".docx":
file.seek(0)
text_content, images = docx_to_text_and_images(file)
return (text_content, images)
# PDF example: we do not show complicated PDF image extraction here
# so we simply extract text for now and skip images.
if extension == ".pdf":
file.seek(0)
text_content, _, images = read_pdf_file(file, pdf_pass, extract_images=True)
return (text_content, images)
# For PPTX, XLSX, EML, etc., we do not show embedded image logic here.
# You can do something similar to docx if needed.
if extension == ".pptx":
file.seek(0)
return (pptx_to_text(file), [])
if extension == ".xlsx":
file.seek(0)
return (xlsx_to_text(file), [])
if extension == ".eml":
file.seek(0)
return (eml_to_text(file), [])
if extension == ".epub":
file.seek(0)
return (epub_to_text(file), [])
if extension == ".html":
file.seek(0)
return (parse_html_page_basic(file), [])
# If we reach here and it's a recognized text extension
if is_text_file_extension(file_name):
file.seek(0)
encoding = detect_encoding(file)
text_content_raw, _ = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
)
return (text_content_raw, [])
# If it's an image file or something else, we do not parse embedded images from them
# just return empty text
file.seek(0)
return ("", [])
except Exception as e:
logger.exception(f"Failed to extract text/images from {file_name}: {e}")
return ("", [])
def convert_docx_to_txt(
file: UploadFile, file_store: FileStore, file_path: str
) -> None:
"""
Helper to convert docx to a .txt file in the same filestore.
"""
file.file.seek(0)
docx_content = file.file.read()
doc = Document(BytesIO(docx_content))
doc = DocxDocument(BytesIO(docx_content))
# Extract text from the document
full_text = []
for para in doc.paragraphs:
full_text.append(para.text)
# Join the extracted text
text_content = "\n".join(full_text)
all_paras = [p.text for p in doc.paragraphs]
text_content = "\n".join(all_paras)
txt_file_path = docx_to_txt_filename(file_path)
file_store.save_file(
@@ -422,7 +525,4 @@ def convert_docx_to_txt(
def docx_to_txt_filename(file_path: str) -> str:
"""
Convert a .docx file path to its corresponding .txt file path.
"""
return file_path.rsplit(".", 1)[0] + ".txt"

View File

@@ -0,0 +1,46 @@
"""
Centralized file type validation utilities.
"""
# Standard image MIME types supported by most vision LLMs
IMAGE_MIME_TYPES = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
]
# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
"image/bmp",
"image/tiff",
"image/gif",
"image/svg+xml",
]
def is_valid_image_type(mime_type: str) -> bool:
"""
Check if mime_type is a valid image type.
Args:
mime_type: The MIME type to check
Returns:
True if the MIME type is a valid image type, False otherwise
"""
if not mime_type:
return False
return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
def is_supported_by_vision_llm(mime_type: str) -> bool:
"""
Check if this image type can be processed by vision LLMs.
Args:
mime_type: The MIME type to check
Returns:
True if the MIME type is supported by vision LLMs, False otherwise
"""
return mime_type in IMAGE_MIME_TYPES

View File

@@ -0,0 +1,129 @@
import base64
from io import BytesIO
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from PIL import Image
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
from onyx.utils.logger import setup_logger
logger = setup_logger()
def prepare_image_bytes(image_data: bytes) -> str:
"""Prepare image bytes for summarization.
Resizes image if it's larger than 20MB. Encodes image as a base64 string."""
image_data = _resize_image_if_needed(image_data)
# encode image (base64)
encoded_image = _encode_image_for_llm_prompt(image_data)
return encoded_image
def summarize_image_pipeline(
llm: LLM,
image_data: bytes,
query: str | None = None,
system_prompt: str | None = None,
) -> str:
"""Pipeline to generate a summary of an image.
Resizes images if it is bigger than 20MB. Encodes image as a base64 string.
And finally uses the Default LLM to generate a textual summary of the image."""
# resize image if it's bigger than 20MB
encoded_image = prepare_image_bytes(image_data)
summary = _summarize_image(
encoded_image,
llm,
query,
system_prompt,
)
return summary
def summarize_image_with_error_handling(
llm: LLM | None,
image_data: bytes,
context_name: str,
system_prompt: str = IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
user_prompt_template: str = IMAGE_SUMMARIZATION_USER_PROMPT,
) -> str | None:
"""Wrapper function that handles error cases and configuration consistently.
Args:
llm: The LLM with vision capabilities to use for summarization
image_data: The raw image bytes
context_name: Name or title of the image for context
system_prompt: System prompt to use for the LLM
user_prompt_template: Template for the user prompt, should contain {title} placeholder
Returns:
The image summary text, or None if summarization failed or is disabled
"""
if llm is None:
return None
user_prompt = user_prompt_template.format(title=context_name)
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
def _summarize_image(
encoded_image: str,
llm: LLM,
query: str | None = None,
system_prompt: str | None = None,
) -> str:
"""Use default LLM (if it is multimodal) to generate a summary of an image."""
messages: list[BaseMessage] = []
if system_prompt:
messages.append(SystemMessage(content=system_prompt))
messages.append(
HumanMessage(
content=[
{"type": "text", "text": query},
{"type": "image_url", "image_url": {"url": encoded_image}},
],
),
)
try:
return message_to_string(llm.invoke(messages))
except Exception as e:
raise ValueError(f"Summarization failed. Messages: {messages}") from e
def _encode_image_for_llm_prompt(image_data: bytes) -> str:
"""Getting the base64 string."""
base64_encoded_data = base64.b64encode(image_data).decode("utf-8")
return f"data:image/jpeg;base64,{base64_encoded_data}"
def _resize_image_if_needed(image_data: bytes, max_size_mb: int = 20) -> bytes:
"""Resize image if it's larger than the specified max size in MB."""
max_size_bytes = max_size_mb * 1024 * 1024
if len(image_data) > max_size_bytes:
with Image.open(BytesIO(image_data)) as img:
# Reduce dimensions for better size reduction
img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
output = BytesIO()
# Save with lower quality for compression
img.save(output, format="JPEG", quality=85)
resized_data = output.getvalue()
return resized_data
return image_data

View File

@@ -0,0 +1,70 @@
from typing import Tuple
from sqlalchemy.orm import Session
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import Section
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
def store_image_and_create_section(
db_session: Session,
image_data: bytes,
file_name: str,
display_name: str,
media_type: str = "image/unknown",
llm: LLM | None = None,
file_origin: FileOrigin = FileOrigin.OTHER,
) -> Tuple[Section, str | None]:
"""
Stores an image in PGFileStore and creates a Section object with optional summarization.
Args:
db_session: Database session
image_data: Raw image bytes
file_name: Base identifier for the file
display_name: Human-readable name for the image
media_type: MIME type of the image
llm: Optional LLM with vision capabilities for summarization
file_origin: Origin of the file (e.g., CONFLUENCE, GOOGLE_DRIVE, etc.)
Returns:
Tuple containing:
- Section object with image reference and optional summary text
- The file_name in PGFileStore or None if storage failed
"""
# Storage logic
stored_file_name = None
try:
pgfilestore = save_bytes_to_pgfilestore(
db_session=db_session,
raw_bytes=image_data,
media_type=media_type,
identifier=file_name,
display_name=display_name,
file_origin=file_origin,
)
stored_file_name = pgfilestore.file_name
except Exception as e:
logger.error(f"Failed to store image: {e}")
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return Section(text=""), None
# Summarization logic
summary_text = ""
if llm:
summary_text = (
summarize_image_with_error_handling(llm, image_data, display_name) or ""
)
return (
Section(text=summary_text, image_file_name=stored_file_name),
stored_file_name,
)

View File

@@ -23,12 +23,9 @@ from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
CHUNK_OVERLAP = 0
# Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
# overwhelm the actual contents of the chunk
# For example in a rare case, this could be 128 tokens for the 512 chunk and title prefix
# could be another 128 tokens leaving 256 for the actual contents
MAX_METADATA_PERCENTAGE = 0.25
CHUNK_MIN_CONTENT = 256
logger = setup_logger()
@@ -36,16 +33,8 @@ def _get_metadata_suffix_for_document_index(
metadata: dict[str, str | list[str]], include_separator: bool = False
) -> tuple[str, str]:
"""
Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding
and a string of all of the values for the keyword search
For example, if we have the following metadata:
{
"author": "John Doe",
"space": "Engineering"
}
The vector embedding string should include the relation between the key and value wheres as for keyword we only want John Doe
and Engineering. The keys are repeat and much more noisy.
Returns the metadata as a natural language string representation with all of the keys and values
for the vector embedding and a string of all of the values for the keyword search.
"""
if not metadata:
return "", ""
@@ -74,12 +63,17 @@ def _get_metadata_suffix_for_document_index(
def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk:
"""
Combines multiple DocAwareChunks into one large chunk (for “multipass” mode),
appending the content and adjusting source_links accordingly.
"""
merged_chunk = DocAwareChunk(
source_document=chunks[0].source_document,
chunk_id=chunks[0].chunk_id,
blurb=chunks[0].blurb,
content=chunks[0].content,
source_links=chunks[0].source_links or {},
image_file_name=None,
section_continuation=(chunks[0].chunk_id > 0),
title_prefix=chunks[0].title_prefix,
metadata_suffix_semantic=chunks[0].metadata_suffix_semantic,
@@ -103,6 +97,9 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]:
"""
Generates larger “grouped” chunks by combining sets of smaller chunks.
"""
large_chunks = []
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)):
chunk_group = chunks[i : i + LARGE_CHUNK_RATIO]
@@ -172,23 +169,60 @@ class Chunker:
while start < total_tokens:
end = min(start + content_token_limit, total_tokens)
token_chunk = tokens[start:end]
# Join the tokens to reconstruct the text
chunk_text = " ".join(token_chunk)
chunks.append(chunk_text)
start = end
return chunks
def _extract_blurb(self, text: str) -> str:
"""
Extract a short blurb from the text (first chunk of size `blurb_size`).
"""
texts = self.blurb_splitter.split_text(text)
if not texts:
return ""
return texts[0]
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
"""
For “multipass” mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
"""
if self.mini_chunk_splitter and chunk_text.strip():
return self.mini_chunk_splitter.split_text(chunk_text)
return None
# ADDED: extra param image_url to store in the chunk
def _create_chunk(
self,
document: Document,
chunks_list: list[DocAwareChunk],
text: str,
links: dict[int, str],
is_continuation: bool = False,
title_prefix: str = "",
metadata_suffix_semantic: str = "",
metadata_suffix_keyword: str = "",
image_file_name: str | None = None,
) -> None:
"""
Helper to create a new DocAwareChunk, append it to chunks_list.
"""
new_chunk = DocAwareChunk(
source_document=document,
chunk_id=len(chunks_list),
blurb=self._extract_blurb(text),
content=text,
source_links=links or {0: ""},
image_file_name=image_file_name,
section_continuation=is_continuation,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=self._get_mini_chunk_texts(text),
large_chunk_id=None,
)
chunks_list.append(new_chunk)
def _chunk_document(
self,
document: Document,
@@ -198,122 +232,156 @@ class Chunker:
content_token_limit: int,
) -> list[DocAwareChunk]:
"""
Loops through sections of the document, adds metadata and converts them into chunks.
Loops through sections of the document, converting them into one or more chunks.
If a section has an image_link, we treat it as a dedicated chunk.
"""
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
chunk_text = ""
def _create_chunk(
text: str,
links: dict[int, str],
is_continuation: bool = False,
) -> DocAwareChunk:
return DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=self._extract_blurb(text),
content=text,
source_links=links or {0: ""},
section_continuation=is_continuation,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=self._get_mini_chunk_texts(text),
large_chunk_id=None,
)
section_link_text: str
for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
section_link_text = section.link or ""
# If there is no useful content, not even the title, just drop it
# ADDED: if the Section has an image link
image_url = section.image_file_name
# If there is no useful content, skip
if not section_text and (not document.title or section_idx > 0):
# If a section is empty and the document has no title, we can just drop it. We return a list of
# DocAwareChunks where each one contains the necessary information needed down the line for indexing.
# There is no concern about dropping whole documents from this list, it should not cause any indexing failures.
logger.warning(
f"Skipping section {section.text} from document "
f"{document.semantic_identifier} due to empty text after cleaning "
f"with link {section_link_text}"
f"Skipping empty or irrelevant section in doc "
f"{document.semantic_identifier}, link={section_link_text}"
)
continue
# CASE 1: If this is an image section, force a separate chunk
if image_url:
# First, if we have any partially built text chunk, finalize it
if chunk_text.strip():
self._create_chunk(
document,
chunks,
chunk_text,
link_offsets,
is_continuation=False,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
chunk_text = ""
link_offsets = {}
# Create a chunk specifically for this image
# (If the section has text describing the image, use that as content)
self._create_chunk(
document,
chunks,
section_text,
links={0: section_link_text}
if section_link_text
else {}, # No text offsets needed for images
image_file_name=image_url,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
# Continue to next section
continue
# CASE 2: Normal text section
section_token_count = len(self.tokenizer.tokenize(section_text))
# Large sections are considered self-contained/unique
# Therefore, they start a new chunk and are not concatenated
# at the end by other sections
# If the section is large on its own, split it separately
if section_token_count > content_token_limit:
if chunk_text:
chunks.append(_create_chunk(chunk_text, link_offsets))
link_offsets = {}
if chunk_text.strip():
self._create_chunk(
document,
chunks,
chunk_text,
link_offsets,
False,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
)
chunk_text = ""
link_offsets = {}
split_texts = self.chunk_splitter.split_text(section_text)
for i, split_text in enumerate(split_texts):
# If even the split_text is bigger than strict limit, further split
if (
STRICT_CHUNK_TOKEN_LIMIT
and
# Tokenizer only runs if STRICT_CHUNK_TOKEN_LIMIT is true
len(self.tokenizer.tokenize(split_text)) > content_token_limit
and len(self.tokenizer.tokenize(split_text))
> content_token_limit
):
# If STRICT_CHUNK_TOKEN_LIMIT is true, manually check
# the token count of each split text to ensure it is
# not larger than the content_token_limit
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
chunks.append(
_create_chunk(
text=small_chunk,
links={0: section_link_text},
is_continuation=(i != 0),
)
for j, small_chunk in enumerate(smaller_chunks):
self._create_chunk(
document,
chunks,
small_chunk,
{0: section_link_text},
is_continuation=(j != 0),
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
else:
chunks.append(
_create_chunk(
text=split_text,
links={0: section_link_text},
is_continuation=(i != 0),
)
self._create_chunk(
document,
chunks,
split_text,
{0: section_link_text},
is_continuation=(i != 0),
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
)
continue
# If we can still fit this section into the current chunk, do so
current_token_count = len(self.tokenizer.tokenize(chunk_text))
current_offset = len(shared_precompare_cleanup(chunk_text))
# In the case where the whole section is shorter than a chunk, either add
# to chunk or start a new one
next_section_tokens = (
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
)
if next_section_tokens + current_token_count <= content_token_limit:
if chunk_text:
chunk_text += SECTION_SEPARATOR
chunk_text += section_text
link_offsets[current_offset] = section_link_text
else:
chunks.append(_create_chunk(chunk_text, link_offsets))
# finalize the existing chunk
self._create_chunk(
document,
chunks,
chunk_text,
link_offsets,
False,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
)
# start a new chunk
link_offsets = {0: section_link_text}
chunk_text = section_text
# Once we hit the end, if we're still in the process of building a chunk, add what we have.
# If there is only whitespace left then don't include it. If there are no chunks at all
# from the doc, we can just create a single chunk with the title.
# finalize any leftover text chunk
if chunk_text.strip() or not chunks:
chunks.append(
_create_chunk(
chunk_text,
link_offsets or {0: section_link_text},
)
self._create_chunk(
document,
chunks,
chunk_text,
link_offsets or {0: ""}, # safe default
False,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
)
# If the chunk does not have any useable content, it will not be indexed
return chunks
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
@@ -321,10 +389,12 @@ class Chunker:
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
# Title prep
title = self._extract_blurb(document.get_title_for_document_index() or "")
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(self.tokenizer.tokenize(title_prefix))
# Metadata prep
metadata_suffix_semantic = ""
metadata_suffix_keyword = ""
metadata_tokens = 0
@@ -337,19 +407,20 @@ class Chunker:
)
metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic))
# If metadata is too large, skip it in the semantic content
if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE:
# Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model
# context, there is no limit for the keyword component
metadata_suffix_semantic = ""
metadata_tokens = 0
# Adjust content token limit to accommodate title + metadata
content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens
# If there is not enough context remaining then just index the chunk with no prefix/suffix
if content_token_limit <= CHUNK_MIN_CONTENT:
# Not enough space left, so revert to full chunk without the prefix
content_token_limit = self.chunk_token_limit
title_prefix = ""
metadata_suffix_semantic = ""
# Chunk the document
normal_chunks = self._chunk_document(
document,
title_prefix,
@@ -358,6 +429,7 @@ class Chunker:
content_token_limit,
)
# Optional “multipass” large chunk creation
if self.enable_multipass and self.enable_large_chunks:
large_chunks = generate_large_chunks(normal_chunks)
normal_chunks.extend(large_chunks)
@@ -371,9 +443,8 @@ class Chunker:
"""
final_chunks: list[DocAwareChunk] = []
for document in documents:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("Chunker.chunk: Stop signal detected")
if self.callback and self.callback.should_stop():
raise RuntimeError("Chunker.chunk: Stop signal detected")
chunks = self._handle_single_document(document)
final_chunks.extend(chunks)

View File

@@ -38,6 +38,7 @@ class IndexingEmbedder(ABC):
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
reduced_dimension: int | None,
callback: IndexingHeartbeatInterface | None,
):
self.model_name = model_name
@@ -60,6 +61,7 @@ class IndexingEmbedder(ABC):
api_url=api_url,
api_version=api_version,
deployment_name=deployment_name,
reduced_dimension=reduced_dimension,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
@@ -87,6 +89,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
callback: IndexingHeartbeatInterface | None = None,
):
super().__init__(
@@ -99,6 +102,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url,
api_version,
deployment_name,
reduced_dimension,
callback,
)
@@ -219,6 +223,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
reduced_dimension=search_settings.reduced_dimension,
callback=callback,
)

View File

@@ -464,12 +464,29 @@ def index_doc_batch(
),
)
successful_doc_ids = {record.document_id for record in insertion_records}
if successful_doc_ids != set(updatable_ids):
all_returned_doc_ids = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Successful IDs: {successful_doc_ids}"
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
)
last_modified_ids = []

View File

@@ -5,6 +5,7 @@ from pydantic import Field
from onyx.access.models import DocumentAccess
from onyx.connectors.models import Document
from onyx.db.enums import EmbeddingPrecision
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
from shared_configs.model_server_models import Embedding
@@ -28,6 +29,7 @@ class BaseChunk(BaseModel):
content: str
# Holds the link and the offsets into the raw Chunk text
source_links: dict[int, str] | None
image_file_name: str | None
# True if this Chunk's start is not at the start of a Section
section_continuation: bool
@@ -143,10 +145,20 @@ class IndexingSetting(EmbeddingModelDetail):
model_dim: int
index_name: str | None
multipass_indexing: bool
embedding_precision: EmbeddingPrecision
reduced_dimension: int | None = None
background_reindex_enabled: bool = True
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@property
def final_embedding_dim(self) -> int:
if self.reduced_dimension:
return self.reduced_dimension
return self.model_dim
@classmethod
def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
return cls(
@@ -158,6 +170,9 @@ class IndexingSetting(EmbeddingModelDetail):
provider_type=search_settings.provider_type,
index_name=search_settings.index_name,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
background_reindex_enabled=search_settings.background_reindex_enabled,
)

View File

@@ -6,12 +6,14 @@ from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@@ -86,6 +88,48 @@ def get_llms_for_persona(
return _create_llm(model), _create_llm(fast_model)
def get_default_llm_with_vision(
timeout: int | None = None,
temperature: float | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM | None:
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
with get_session_context_manager() as db_session:
llm_providers = fetch_existing_llm_providers(db_session)
if not llm_providers:
return None
for provider in llm_providers:
model_name = provider.default_model_name
fast_model_name = (
provider.fast_default_model_name or provider.default_model_name
)
if not model_name or not fast_model_name:
continue
if model_supports_image_input(model_name, provider.provider):
return get_llm(
provider=provider.provider,
model=model_name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
raise ValueError("No LLM provider found that supports image input")
def get_default_llms(
timeout: int | None = None,
temperature: float | None = None,

View File

@@ -51,7 +51,7 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as oauth_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
from onyx.server.features.input_prompt.api import (
@@ -323,7 +323,7 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, oauth_router)
include_router_with_global_prefix_prepended(application, standard_oauth_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@@ -89,6 +89,7 @@ class EmbeddingModel:
callback: IndexingHeartbeatInterface | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -100,6 +101,7 @@ class EmbeddingModel:
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.reduced_dimension = reduced_dimension
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
@@ -188,6 +190,7 @@ class EmbeddingModel:
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
api_url=self.api_url,
reduced_dimension=self.reduced_dimension,
)
start_time = time.time()
@@ -300,6 +303,7 @@ class EmbeddingModel:
retrim_content=retrim_content,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
reduced_dimension=search_settings.reduced_dimension,
)

View File

@@ -31,12 +31,18 @@ from onyx.onyxbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
from onyx.onyxbot.slack.formatting import format_slack_message
from onyx.onyxbot.slack.icons import source_to_github_img_link
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessage
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageChannelConfig
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageMessageInfo
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import build_continue_in_web_ui_id
from onyx.onyxbot.slack.utils import build_feedback_id
from onyx.onyxbot.slack.utils import build_publish_ephemeral_message_id
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
from onyx.onyxbot.slack.utils import translate_vespa_highlight_to_slack
from onyx.utils.text_processing import decode_escapes
@@ -105,6 +111,77 @@ def _build_qa_feedback_block(
)
def _build_ephemeral_publication_block(
channel_id: str,
chat_message_id: int,
message_info: SlackMessageInfo,
original_question_ts: str,
channel_conf: ChannelConfig,
feedback_reminder_id: str | None = None,
) -> Block:
# check whether the message is in a thread
if (
message_info is not None
and message_info.msg_to_respond is not None
and message_info.thread_to_respond is not None
and (message_info.msg_to_respond == message_info.thread_to_respond)
):
respond_ts = None
else:
respond_ts = original_question_ts
action_values_ephemeral_message_channel_config = (
ActionValuesEphemeralMessageChannelConfig(
channel_name=channel_conf.get("channel_name"),
respond_tag_only=channel_conf.get("respond_tag_only"),
respond_to_bots=channel_conf.get("respond_to_bots"),
is_ephemeral=channel_conf.get("is_ephemeral", False),
respond_member_group_list=channel_conf.get("respond_member_group_list"),
answer_filters=channel_conf.get("answer_filters"),
follow_up_tags=channel_conf.get("follow_up_tags"),
show_continue_in_web_ui=channel_conf.get("show_continue_in_web_ui", False),
)
)
action_values_ephemeral_message_message_info = (
ActionValuesEphemeralMessageMessageInfo(
bypass_filters=message_info.bypass_filters,
channel_to_respond=message_info.channel_to_respond,
msg_to_respond=message_info.msg_to_respond,
email=message_info.email,
sender_id=message_info.sender_id,
thread_messages=[],
is_bot_msg=message_info.is_bot_msg,
is_bot_dm=message_info.is_bot_dm,
thread_to_respond=respond_ts,
)
)
action_values_ephemeral_message = ActionValuesEphemeralMessage(
original_question_ts=original_question_ts,
feedback_reminder_id=feedback_reminder_id,
chat_message_id=chat_message_id,
message_info=action_values_ephemeral_message_message_info,
channel_conf=action_values_ephemeral_message_channel_config,
)
return ActionsBlock(
block_id=build_publish_ephemeral_message_id(original_question_ts),
elements=[
ButtonElement(
action_id=SHOW_EVERYONE_ACTION_ID,
text="📢 Share with Everyone",
value=action_values_ephemeral_message.model_dump_json(),
),
ButtonElement(
action_id=KEEP_TO_YOURSELF_ACTION_ID,
text="🤫 Keep to Yourself",
value=action_values_ephemeral_message.model_dump_json(),
),
],
)
def get_document_feedback_blocks() -> Block:
return SectionBlock(
text=(
@@ -486,16 +563,21 @@ def build_slack_response_blocks(
use_citations: bool,
feedback_reminder_id: str | None,
skip_ai_feedback: bool = False,
offer_ephemeral_publication: bool = False,
expecting_search_result: bool = False,
skip_restated_question: bool = False,
) -> list[Block]:
"""
This function is a top level function that builds all the blocks for the Slack response.
It also handles combining all the blocks together.
"""
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
if not skip_restated_question:
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
else:
restate_question_block = []
if expecting_search_result:
answer_blocks = _build_qa_response_blocks(
@@ -520,12 +602,36 @@ def build_slack_response_blocks(
)
follow_up_block = []
if channel_conf and channel_conf.get("follow_up_tags") is not None:
if (
channel_conf
and channel_conf.get("follow_up_tags") is not None
and not channel_conf.get("is_ephemeral", False)
):
follow_up_block.append(
_build_follow_up_block(message_id=answer.chat_message_id)
)
ai_feedback_block = []
publish_ephemeral_message_block = []
if (
offer_ephemeral_publication
and answer.chat_message_id is not None
and message_info.msg_to_respond is not None
and channel_conf is not None
):
publish_ephemeral_message_block.append(
_build_ephemeral_publication_block(
channel_id=message_info.channel_to_respond,
chat_message_id=answer.chat_message_id,
original_question_ts=message_info.msg_to_respond,
message_info=message_info,
channel_conf=channel_conf,
feedback_reminder_id=feedback_reminder_id,
)
)
ai_feedback_block: list[Block] = []
if answer.chat_message_id is not None and not skip_ai_feedback:
ai_feedback_block.append(
_build_qa_feedback_block(
@@ -547,6 +653,7 @@ def build_slack_response_blocks(
all_blocks = (
restate_question_block
+ answer_blocks
+ publish_ephemeral_message_block
+ ai_feedback_block
+ citations_divider
+ citations_blocks

View File

@@ -2,6 +2,8 @@ from enum import Enum
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
SHOW_EVERYONE_ACTION_ID = "show-everyone"
KEEP_TO_YOURSELF_ACTION_ID = "keep-to-yourself"
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"

View File

@@ -1,3 +1,4 @@
import json
from typing import Any
from typing import cast
@@ -5,21 +6,32 @@ from slack_sdk import WebClient
from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.views import View
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.webhook import WebhookClient
from onyx.chat.models import ChatOnyxBotResponse
from onyx.chat.models import CitationInfo
from onyx.chat.models import QADocsResponse
from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.db.users import get_user_by_email
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.blocks import get_document_feedback_blocks
from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import FeedbackVisibility
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from onyx.onyxbot.slack.handlers.handle_message import (
remove_scheduled_feedback_reminder,
@@ -35,15 +47,48 @@ from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails
from onyx.onyxbot.slack.utils import get_channel_name_from_id
from onyx.onyxbot.slack.utils import get_feedback_visibility
from onyx.onyxbot.slack.utils import read_slack_thread
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import TenantSocketModeClient
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _convert_db_doc_id_to_document_ids(
citation_dict: dict[int, int], top_documents: list[SavedSearchDoc]
) -> list[CitationInfo]:
citation_list_with_document_id = []
for citation_num, db_doc_id in citation_dict.items():
if db_doc_id is not None:
matching_doc = next(
(d for d in top_documents if d.db_doc_id == db_doc_id), None
)
if matching_doc:
citation_list_with_document_id.append(
CitationInfo(
citation_num=citation_num, document_id=matching_doc.document_id
)
)
return citation_list_with_document_id
def _build_citation_list(chat_message_detail: ChatMessageDetail) -> list[CitationInfo]:
citation_dict = chat_message_detail.citations
if citation_dict is None:
return []
else:
top_documents = (
chat_message_detail.context_docs.top_documents
if chat_message_detail.context_docs
else []
)
citation_list = _convert_db_doc_id_to_document_ids(citation_dict, top_documents)
return citation_list
def handle_doc_feedback_button(
req: SocketModeRequest,
client: TenantSocketModeClient,
@@ -58,7 +103,7 @@ def handle_doc_feedback_button(
external_id = build_feedback_id(query_event_id, doc_id, doc_rank)
channel_id = req.payload["container"]["channel_id"]
thread_ts = req.payload["container"]["thread_ts"]
thread_ts = req.payload["container"].get("thread_ts", None)
data = View(
type="modal",
@@ -84,7 +129,7 @@ def handle_generate_answer_button(
channel_id = req.payload["channel"]["id"]
channel_name = req.payload["channel"]["name"]
message_ts = req.payload["message"]["ts"]
thread_ts = req.payload["container"]["thread_ts"]
thread_ts = req.payload["container"].get("thread_ts", None)
user_id = req.payload["user"]["id"]
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
email = expert_info.email if expert_info else None
@@ -106,7 +151,7 @@ def handle_generate_answer_button(
# tell the user that we're working on it
# Send an ephemeral message to the user that we're generating the answer
respond_in_thread(
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
receiver_ids=[user_id],
@@ -142,6 +187,178 @@ def handle_generate_answer_button(
)
def handle_publish_ephemeral_message_button(
req: SocketModeRequest,
client: TenantSocketModeClient,
action_id: str,
) -> None:
"""
This function handles the Share with Everyone/Keep for Yourself buttons
for ephemeral messages.
"""
channel_id = req.payload["channel"]["id"]
ephemeral_message_ts = req.payload["container"]["message_ts"]
slack_sender_id = req.payload["user"]["id"]
response_url = req.payload["response_url"]
webhook = WebhookClient(url=response_url)
# The additional data required that was added to buttons.
# Specifically, this contains the message_info, channel_conf information
# and some additional attributes.
value_dict = json.loads(req.payload["actions"][0]["value"])
original_question_ts = value_dict.get("original_question_ts")
if not original_question_ts:
raise ValueError("Missing original_question_ts in the payload")
if not ephemeral_message_ts:
raise ValueError("Missing ephemeral_message_ts in the payload")
feedback_reminder_id = value_dict.get("feedback_reminder_id")
slack_message_info = SlackMessageInfo(**value_dict["message_info"])
channel_conf = value_dict.get("channel_conf")
user_email = value_dict.get("message_info", {}).get("email")
chat_message_id = value_dict.get("chat_message_id")
# Obtain onyx_user and chat_message information
if not chat_message_id:
raise ValueError("Missing chat_message_id in the payload")
with get_session_with_current_tenant() as db_session:
onyx_user = get_user_by_email(user_email, db_session)
if not onyx_user:
raise ValueError("Cannot determine onyx_user_id from email in payload")
try:
chat_message = get_chat_message(chat_message_id, onyx_user.id, db_session)
except ValueError:
chat_message = get_chat_message(
chat_message_id, None, db_session
) # is this good idea?
except Exception as e:
logger.error(f"Failed to get chat message: {e}")
raise e
chat_message_detail = translate_db_message_to_chat_message_detail(chat_message)
# construct the proper citation format and then the answer in the suitable format
# we need to construct the blocks.
citation_list = _build_citation_list(chat_message_detail)
onyx_bot_answer = ChatOnyxBotResponse(
answer=chat_message_detail.message,
citations=citation_list,
chat_message_id=chat_message_id,
docs=QADocsResponse(
top_documents=chat_message_detail.context_docs.top_documents
if chat_message_detail.context_docs
else [],
predicted_flow=None,
predicted_search=None,
applied_source_filters=None,
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
),
llm_selected_doc_indices=None,
error_msg=None,
)
# Note: we need to use the webhook and the respond_url to update/delete ephemeral messages
if action_id == SHOW_EVERYONE_ACTION_ID:
# Convert to non-ephemeral message in thread
try:
webhook.send(
response_type="ephemeral",
text="",
blocks=[],
replace_original=True,
delete_original=True,
)
except Exception as e:
logger.error(f"Failed to send webhook: {e}")
# remove handling of empheremal block and add AI feedback.
all_blocks = build_slack_response_blocks(
answer=onyx_bot_answer,
message_info=slack_message_info,
channel_conf=channel_conf,
use_citations=True,
feedback_reminder_id=feedback_reminder_id,
skip_ai_feedback=False,
offer_ephemeral_publication=False,
skip_restated_question=True,
)
try:
# Post in thread as non-ephemeral message
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
receiver_ids=None, # If respond_member_group_list is set, send to them. TODO: check!
text="Hello! Onyx has some results for you!",
blocks=all_blocks,
thread_ts=original_question_ts,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
send_as_ephemeral=False,
)
except Exception as e:
logger.error(f"Failed to publish ephemeral message: {e}")
raise e
elif action_id == KEEP_TO_YOURSELF_ACTION_ID:
# Keep as ephemeral message in channel or thread, but remove the publish button and add feedback button
changed_blocks = build_slack_response_blocks(
answer=onyx_bot_answer,
message_info=slack_message_info,
channel_conf=channel_conf,
use_citations=True,
feedback_reminder_id=feedback_reminder_id,
skip_ai_feedback=False,
offer_ephemeral_publication=False,
skip_restated_question=True,
)
try:
if slack_message_info.thread_to_respond is not None:
# There seems to be a bug in slack where an update within the thread
# actually leads to the update to be posted in the channel. Therefore,
# for now we delete the original ephemeral message and post a new one
# if the ephemeral message is in a thread.
webhook.send(
response_type="ephemeral",
text="",
blocks=[],
replace_original=True,
delete_original=True,
)
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
receiver_ids=[slack_sender_id],
text="Your personal response, sent as an ephemeral message.",
blocks=changed_blocks,
thread_ts=original_question_ts,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
send_as_ephemeral=True,
)
else:
# This works fine if the ephemeral message is in the channel
webhook.send(
response_type="ephemeral",
text="Your personal response, sent as an ephemeral message.",
blocks=changed_blocks,
replace_original=True,
delete_original=False,
)
except Exception as e:
logger.error(f"Failed to send webhook: {e}")
def handle_slack_feedback(
feedback_id: str,
feedback_type: str,
@@ -153,13 +370,20 @@ def handle_slack_feedback(
) -> None:
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
# Get Onyx user from Slack ID
expert_info = expert_info_from_slack_id(
user_id_to_post_confirmation, client, user_cache={}
)
email = expert_info.email if expert_info else None
with get_session_with_current_tenant() as db_session:
onyx_user = get_user_by_email(email, db_session) if email else None
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
feedback_text="",
chat_message_id=message_id,
user_id=None, # no "user" for Slack bot for now
user_id=onyx_user.id if onyx_user else None,
db_session=db_session,
)
remove_scheduled_feedback_reminder(
@@ -213,7 +437,7 @@ def handle_slack_feedback(
else:
msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer"
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel_id_to_post_confirmation,
text=msg,
@@ -232,7 +456,7 @@ def handle_followup_button(
action_id = cast(str, action.get("block_id"))
channel_id = req.payload["container"]["channel_id"]
thread_ts = req.payload["container"]["thread_ts"]
thread_ts = req.payload["container"].get("thread_ts", None)
update_emote_react(
emoji=DANSWER_FOLLOWUP_EMOJI,
@@ -265,7 +489,7 @@ def handle_followup_button(
blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)
respond_in_thread(
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
text="Received your request for more help",
@@ -315,7 +539,7 @@ def handle_followup_resolved_button(
) -> None:
channel_id = req.payload["container"]["channel_id"]
message_ts = req.payload["container"]["message_ts"]
thread_ts = req.payload["container"]["thread_ts"]
thread_ts = req.payload["container"].get("thread_ts", None)
clicker_name = get_clicker_name(req, client)
@@ -349,7 +573,7 @@ def handle_followup_resolved_button(
resolved_block = SectionBlock(text=msg_text)
respond_in_thread(
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
text="Your request for help as been addressed!",

View File

@@ -18,7 +18,7 @@ from onyx.onyxbot.slack.handlers.handle_standard_answers import (
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails
from onyx.onyxbot.slack.utils import fetch_user_ids_from_groups
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import slack_usage_report
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.utils.logger import setup_logger
@@ -29,7 +29,7 @@ logger_base = setup_logger()
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender_id:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=details.channel_to_respond,
thread_ts=details.msg_to_respond,
@@ -202,7 +202,7 @@ def handle_message(
# which would just respond to the sender
if send_to and is_bot_msg:
if sender_id:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=[sender_id],
@@ -220,6 +220,7 @@ def handle_message(
add_slack_user_if_not_exists(db_session, message_info.email)
# first check if we need to respond with a standard answer
# standard answers should be published in a thread
used_standard_answer = handle_standard_answers(
message_info=message_info,
receiver_ids=send_to,

View File

@@ -33,7 +33,7 @@ from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.handlers.utils import slackify_message_thread
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import SlackRateLimiter
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.query_and_chat.models import CreateChatMessageRequest
@@ -82,12 +82,38 @@ def handle_regular_answer(
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
# Capture whether response mode for channel is ephemeral. Even if the channel is set
# to respond with an ephemeral message, we still send as non-ephemeral if
# the message is a dm with the Onyx bot.
send_as_ephemeral = (
slack_channel_config.channel_config.get("is_ephemeral", False)
and not message_info.is_bot_dm
)
# If the channel mis configured to respond with an ephemeral message,
# or the message is a dm to the Onyx bot, we should use the proper onyx user from the email.
# This will make documents privately accessible to the user available to Onyx Bot answers.
# Otherwise - if not ephemeral or DM to Onyx Bot - we must use None as the user to restrict
# to public docs.
user = None
if message_info.is_bot_dm:
if message_info.is_bot_dm or send_as_ephemeral:
if message_info.email:
with get_session_with_current_tenant() as db_session:
user = get_user_by_email(message_info.email, db_session)
target_thread_ts = (
None
if send_as_ephemeral and len(message_info.thread_messages) < 2
else message_ts_to_respond_to
)
target_receiver_ids = (
[message_info.sender_id]
if message_info.sender_id and send_as_ephemeral
else receiver_ids
)
document_set_names: list[str] | None = None
prompt = None
# If no persona is specified, use the default search based persona
@@ -134,11 +160,10 @@ def handle_regular_answer(
history_messages = messages[:-1]
single_message_history = slackify_message_thread(history_messages) or None
# Always check for ACL permissions, also for documnt sets that were explicitly added
# to the Bot by the Administrator. (Change relative to earlier behavior where all documents
# in an attached document set were available to all users in the channel.)
bypass_acl = False
if slack_channel_config.persona and slack_channel_config.persona.document_sets:
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
if not message_ts_to_respond_to and not is_bot_msg:
# if the message is not "/onyx" command, then it should have a message ts to respond to
@@ -219,12 +244,13 @@ def handle_regular_answer(
# Optionally, respond in thread with the error message, Used primarily
# for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=None,
receiver_ids=target_receiver_ids,
text=f"Encountered exception when trying to answer: \n\n```{e}```",
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
)
# In case of failures, don't keep the reaction there permanently
@@ -242,32 +268,36 @@ def handle_regular_answer(
if answer is None:
assert DISABLE_GENERATIVE_AI is True
try:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=receiver_ids,
receiver_ids=target_receiver_ids,
text="Hello! Onyx has some results for you!",
blocks=[
SectionBlock(
text="Onyx is down for maintenance.\nWe're working hard on recharging the AI!"
)
],
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
if receiver_ids:
respond_in_thread(
# If the channel is ephemeral, we don't need to send a message to the user since they will already see the message
if target_receiver_ids and not send_as_ephemeral:
respond_in_thread_or_channel(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
)
return False
@@ -316,12 +346,13 @@ def handle_regular_answer(
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=None,
receiver_ids=target_receiver_ids,
text="Found no documents when trying to answer. Did you index any documents?",
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
)
return True
@@ -349,15 +380,27 @@ def handle_regular_answer(
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=None,
receiver_ids=target_receiver_ids,
text="Found no citations or quotes when trying to answer.",
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
)
return True
if (
send_as_ephemeral
and target_receiver_ids is not None
and len(target_receiver_ids) == 1
):
offer_ephemeral_publication = True
skip_ai_feedback = True
else:
offer_ephemeral_publication = False
skip_ai_feedback = False if feedback_reminder_id else True
all_blocks = build_slack_response_blocks(
message_info=message_info,
answer=answer,
@@ -365,31 +408,39 @@ def handle_regular_answer(
use_citations=True, # No longer supporting quotes
feedback_reminder_id=feedback_reminder_id,
expecting_search_result=expecting_search_result,
offer_ephemeral_publication=offer_ephemeral_publication,
skip_ai_feedback=skip_ai_feedback,
)
try:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=[message_info.sender_id]
if message_info.is_bot_msg and message_info.sender_id
else receiver_ids,
receiver_ids=target_receiver_ids,
text="Hello! Onyx has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
send_as_ephemeral=send_as_ephemeral,
)
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /onyx message
# so we shouldn't send_team_member_message
if receiver_ids and message_ts_to_respond_to is not None:
if (
target_receiver_ids
and message_ts_to_respond_to is not None
and not send_as_ephemeral
and target_thread_ts is not None
):
send_team_member_message(
client=client,
channel=channel,
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
receiver_ids=target_receiver_ids,
send_as_ephemeral=send_as_ephemeral,
)
return False

View File

@@ -2,7 +2,7 @@ from slack_sdk import WebClient
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import MessageType
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
@@ -32,8 +32,10 @@ def send_team_member_message(
client: WebClient,
channel: str,
thread_ts: str,
receiver_ids: list[str] | None = None,
send_as_ephemeral: bool = False,
) -> None:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
text=(
@@ -41,4 +43,6 @@ def send_team_member_message(
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=thread_ts,
receiver_ids=None,
send_as_ephemeral=send_as_ephemeral,
)

View File

@@ -57,7 +57,9 @@ from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from onyx.onyxbot.slack.handlers.handle_buttons import handle_doc_feedback_button
from onyx.onyxbot.slack.handlers.handle_buttons import handle_followup_button
@@ -67,6 +69,9 @@ from onyx.onyxbot.slack.handlers.handle_buttons import (
from onyx.onyxbot.slack.handlers.handle_buttons import (
handle_generate_answer_button,
)
from onyx.onyxbot.slack.handlers.handle_buttons import (
handle_publish_ephemeral_message_button,
)
from onyx.onyxbot.slack.handlers.handle_buttons import handle_slack_feedback
from onyx.onyxbot.slack.handlers.handle_message import handle_message
from onyx.onyxbot.slack.handlers.handle_message import (
@@ -81,7 +86,7 @@ from onyx.onyxbot.slack.utils import get_onyx_bot_slack_bot_id
from onyx.onyxbot.slack.utils import read_slack_thread
from onyx.onyxbot.slack.utils import remove_onyx_bot_tag
from onyx.onyxbot.slack.utils import rephrase_slack_message
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import TenantSocketModeClient
from onyx.redis.redis_pool import get_redis_client
from onyx.server.manage.models import SlackBotTokens
@@ -667,7 +672,11 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
feedback_msg_reminder = cast(str, action.get("value"))
feedback_id = cast(str, action.get("block_id"))
channel_id = cast(str, req.payload["container"]["channel_id"])
thread_ts = cast(str, req.payload["container"]["thread_ts"])
thread_ts = cast(
str,
req.payload["container"].get("thread_ts")
or req.payload["container"].get("message_ts"),
)
else:
logger.error("Unable to process feedback. Action not found")
return
@@ -783,7 +792,7 @@ def apologize_for_fail(
details: SlackMessageInfo,
client: TenantSocketModeClient,
) -> None:
respond_in_thread(
respond_in_thread_or_channel(
client=client.web_client,
channel=details.channel_to_respond,
thread_ts=details.msg_to_respond,
@@ -859,6 +868,14 @@ def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> No
if action["action_id"] in [DISLIKE_BLOCK_ACTION_ID, LIKE_BLOCK_ACTION_ID]:
# AI Answer feedback
return process_feedback(req, client)
elif action["action_id"] in [
SHOW_EVERYONE_ACTION_ID,
KEEP_TO_YOURSELF_ACTION_ID,
]:
# Publish ephemeral message or keep hidden in main channel
return handle_publish_ephemeral_message_button(
req, client, action["action_id"]
)
elif action["action_id"] == FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID:
# Activation of the "source feedback" button
return handle_doc_feedback_button(req, client)

View File

@@ -1,3 +1,5 @@
from typing import Literal
from pydantic import BaseModel
from onyx.chat.models import ThreadMessage
@@ -13,3 +15,37 @@ class SlackMessageInfo(BaseModel):
bypass_filters: bool # User has tagged @OnyxBot
is_bot_msg: bool # User is using /OnyxBot
is_bot_dm: bool # User is direct messaging to OnyxBot
# Models used to encode the relevant data for the ephemeral message actions
class ActionValuesEphemeralMessageMessageInfo(BaseModel):
bypass_filters: bool | None
channel_to_respond: str | None
msg_to_respond: str | None
email: str | None
sender_id: str | None
thread_messages: list[ThreadMessage] | None
is_bot_msg: bool | None
is_bot_dm: bool | None
thread_to_respond: str | None
class ActionValuesEphemeralMessageChannelConfig(BaseModel):
channel_name: str | None
respond_tag_only: bool | None
respond_to_bots: bool | None
is_ephemeral: bool
respond_member_group_list: list[str] | None
answer_filters: list[
Literal["well_answered_postfilter", "questionmark_prefilter"]
] | None
follow_up_tags: list[str] | None
show_continue_in_web_ui: bool
class ActionValuesEphemeralMessage(BaseModel):
original_question_ts: str | None
feedback_reminder_id: str | None
chat_message_id: int
message_info: ActionValuesEphemeralMessageMessageInfo
channel_conf: ActionValuesEphemeralMessageChannelConfig

View File

@@ -184,7 +184,7 @@ def _build_error_block(error_message: str) -> Block:
backoff=2,
logger=cast(logging.Logger, logger),
)
def respond_in_thread(
def respond_in_thread_or_channel(
client: WebClient,
channel: str,
thread_ts: str | None,
@@ -193,6 +193,7 @@ def respond_in_thread(
receiver_ids: list[str] | None = None,
metadata: Metadata | None = None,
unfurl: bool = True,
send_as_ephemeral: bool | None = True,
) -> list[str]:
if not text and not blocks:
raise ValueError("One of `text` or `blocks` must be provided")
@@ -236,6 +237,7 @@ def respond_in_thread(
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
@@ -299,6 +301,12 @@ def build_feedback_id(
return unique_prefix + ID_SEPARATOR + feedback_id
def build_publish_ephemeral_message_id(
original_question_ts: str,
) -> str:
return "publish_ephemeral_message__" + original_question_ts
def build_continue_in_web_ui_id(
message_id: int,
) -> str:
@@ -539,7 +547,7 @@ def read_slack_thread(
# If auto-detected filters are on, use the second block for the actual answer
# The first block is the auto-detected filters
if message.startswith("_Filters"):
if message is not None and message.startswith("_Filters"):
if len(blocks) < 2:
logger.warning(f"Only filter blocks found: {reply}")
continue
@@ -611,7 +619,7 @@ class SlackRateLimiter:
def notify(
self, client: WebClient, channel: str, position: int, thread_ts: str | None
) -> None:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=None,

View File

@@ -0,0 +1,22 @@
# Used for creating embeddings of images for vector search
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
You are an assistant for summarizing images for retrieval.
Summarize the content of the following image and be as precise as possible.
The summary will be embedded and used to retrieve the original image.
Therefore, write a concise summary of the image that is optimized for retrieval.
"""
# Prompt for generating image descriptions with filename context
IMAGE_SUMMARIZATION_USER_PROMPT = """
The image has the file name '{title}'.
Describe precisely and concisely what the image shows.
"""
# Used for analyzing images in response to user queries at search time
IMAGE_ANALYSIS_SYSTEM_PROMPT = (
"You are an AI assistant specialized in describing images.\n"
"You will receive a user question plus an image URL. Provide a concise textual answer.\n"
"Focus on aspects of the image that are relevant to the user's question.\n"
"Be specific and detailed about visual elements that directly address the query.\n"
)

View File

@@ -55,7 +55,11 @@ def _create_indexable_chunks(
# The section is not really used past this point since we have already done the other processing
# for the chunking and embedding.
sections=[
Section(text=preprocessed_doc["content"], link=preprocessed_doc["url"])
Section(
text=preprocessed_doc["content"],
link=preprocessed_doc["url"],
image_file_name=None,
)
],
source=DocumentSource.WEB,
semantic_identifier=preprocessed_doc["title"],
@@ -93,6 +97,7 @@ def _create_indexable_chunks(
document_sets=set(),
boost=DEFAULT_BOOST,
large_chunk_id=None,
image_file_name=None,
)
chunks.append(chunk)

View File

@@ -13,6 +13,7 @@ from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
from onyx.db.credentials import delete_credential
from onyx.db.credentials import delete_credential_for_user
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import fetch_credentials_by_source_for_user
from onyx.db.credentials import fetch_credentials_for_user
@@ -88,7 +89,7 @@ def delete_credential_by_id_admin(
db_session: Session = Depends(get_session),
) -> StatusResponse:
"""Same as the user endpoint, but can delete any credential (not just the user's own)"""
delete_credential(db_session=db_session, credential_id=credential_id, user=None)
delete_credential(db_session=db_session, credential_id=credential_id)
return StatusResponse(
success=True, message="Credential deleted successfully", data=credential_id
)
@@ -242,7 +243,7 @@ def delete_credential_by_id(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
delete_credential(
delete_credential_for_user(
credential_id,
user,
db_session,
@@ -259,7 +260,7 @@ def force_delete_credential_by_id(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
delete_credential(credential_id, user, db_session, True)
delete_credential_for_user(credential_id, user, db_session, True)
return StatusResponse(
success=True, message="Credential deleted successfully", data=credential_id

View File

@@ -181,6 +181,7 @@ class SlackChannelConfigCreationRequest(BaseModel):
channel_name: str
respond_tag_only: bool = False
respond_to_bots: bool = False
is_ephemeral: bool = False
show_continue_in_web_ui: bool = False
enable_auto_filters: bool = False
# If no team members, assume respond in the channel to everyone

Some files were not shown because too many files have changed in this diff Show More