mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-25 01:22:45 +00:00
Compare commits
1 Commits
v0.20.0-cl
...
pr-3208
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2357d9234f |
@@ -546,6 +546,13 @@ except json.JSONDecodeError:
|
||||
# LLM Model Update API endpoint
|
||||
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
|
||||
|
||||
# Multimodal-Settings
|
||||
# # enable usage of summaries
|
||||
# -> add summaries to Vespa when indexing and therefore use them in the answer generation as well
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED = (
|
||||
os.environ.get("CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
|
||||
@@ -93,3 +93,29 @@ BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
# Custom System Prompt for Image Summarization
|
||||
# Results may be better if the prompt is in the language of the main language of the source
|
||||
# if no prompt provided by user a default prompt is used:
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = (
|
||||
os.environ.get("CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT")
|
||||
or """
|
||||
You are an assistant for summarizing images for retrieval.
|
||||
Summarize the content of the following image and be as precise as possible.
|
||||
The summary will be embedded and used to retrieve the original image.
|
||||
Therefore, write a concise summary of the image that is optimized for retrieval.
|
||||
"""
|
||||
)
|
||||
|
||||
# Custom User Prompt for Image Summarization
|
||||
# Results may be better if the prompt is in the language of the main language of the source
|
||||
# if no prompt provided by user a default prompt is used:
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT = (
|
||||
os.environ.get("CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT")
|
||||
or """
|
||||
The image has the file name '{title}' and is embedded on a Confluence page with the title '{page_title}'.
|
||||
Describe precisely and concisely what the image shows in the context of the page and what it is used for.
|
||||
The following is the XML source text of the page:
|
||||
{confluence_xml}
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -4,7 +4,8 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP, CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED
|
||||
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -29,6 +30,7 @@ from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Potential Improvements
|
||||
@@ -46,6 +48,7 @@ _ATTACHMENT_EXPANSION_FIELDS = [
|
||||
"version",
|
||||
"space",
|
||||
"metadata.labels",
|
||||
"history.lastUpdated",
|
||||
]
|
||||
|
||||
_RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
@@ -129,7 +132,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
# If image summarization is enabled, but not supported by the default LLM, raise an error.
|
||||
if CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED:
|
||||
llm, _ = get_default_llms()
|
||||
if not llm.vision_support():
|
||||
raise ValueError(
|
||||
"The configured default LLM doesn't seem to have vision support for image summarization."
|
||||
)
|
||||
|
||||
self.llm = llm
|
||||
else:
|
||||
self.llm = None
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
@@ -194,58 +207,53 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
return comment_string
|
||||
|
||||
def _convert_object_to_document(
|
||||
self, confluence_object: dict[str, Any]
|
||||
) -> Document | None:
|
||||
def _convert_page_to_document(self, page: dict[str, Any]) -> Document | None:
|
||||
"""
|
||||
Takes in a confluence object, extracts all metadata, and converts it into a document.
|
||||
If its a page, it extracts the text, adds the comments for the document text.
|
||||
If its an attachment, it just downloads the attachment and converts that into a document.
|
||||
If it's a page, it extracts the text, adds the comments for the document text.
|
||||
If it's an attachment, it just downloads the attachment and converts that into a document.
|
||||
If image summarization is enabled (env var CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED is set),
|
||||
images are extracted and summarized by the default LLM.
|
||||
"""
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
object_text = None
|
||||
# Extract text from page
|
||||
if confluence_object["type"] == "page":
|
||||
object_text = extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=confluence_object,
|
||||
fetched_titles={confluence_object.get("title", "")},
|
||||
)
|
||||
# Add comments to text
|
||||
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
|
||||
elif confluence_object["type"] == "attachment":
|
||||
object_text = attachment_to_content(
|
||||
confluence_client=self.confluence_client, attachment=confluence_object
|
||||
)
|
||||
|
||||
if object_text is None:
|
||||
# This only happens for attachments that are not parseable
|
||||
return None
|
||||
logger.notice(f"processing page: {object_url}")
|
||||
|
||||
# Get space name
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": confluence_object["space"]["name"]
|
||||
"Wiki Space Name": page["space"]["name"]
|
||||
}
|
||||
|
||||
# Get labels
|
||||
label_dicts = confluence_object["metadata"]["labels"]["results"]
|
||||
label_dicts = page["metadata"]["labels"]["results"]
|
||||
page_labels = [label["name"] for label in label_dicts]
|
||||
if page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
# Get last modified and author email
|
||||
last_modified = datetime_from_string(confluence_object["version"]["when"])
|
||||
author_email = confluence_object["version"].get("by", {}).get("email")
|
||||
last_modified = datetime_from_string(page["version"]["when"])
|
||||
author_email = page["version"].get("by", {}).get("email")
|
||||
|
||||
# Extract text from page
|
||||
object_text = extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=page,
|
||||
fetched_titles={page.get("title", "")},
|
||||
)
|
||||
# Add comments to text
|
||||
object_text += self._get_comment_string_for_page_id(page["id"])
|
||||
|
||||
if object_text is None:
|
||||
# This only happens for attachments that are not parsable
|
||||
return None
|
||||
|
||||
return Document(
|
||||
id=object_url,
|
||||
sections=[Section(link=object_url, text=object_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=confluence_object["title"],
|
||||
semantic_identifier=page["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author_email)] if author_email else None
|
||||
@@ -269,11 +277,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
doc = self._convert_page_to_document(page)
|
||||
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
@@ -282,16 +290,37 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_query = self._construct_attachment_query(confluence_page_id)
|
||||
# TODO: maybe should add time filter as well?
|
||||
# fetch attachments of each page directly after each page
|
||||
# to be able to use the XML text of each page as context when summarizing the images of each page
|
||||
# (only if CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED = True, otherwise images will be skipped)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
|
||||
confluence_xml = page["body"]["storage"]["value"]
|
||||
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
content = attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_context=confluence_xml,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
if content:
|
||||
# get link to attachment in webui
|
||||
webui_link = _attachment_to_webui_link(
|
||||
self.confluence_client, attachment
|
||||
)
|
||||
# add prefix to text to mark attachments as such
|
||||
text = f"## Text representation of attachment {attachment['title']}:\n{content}"
|
||||
doc.sections.append(Section(text=text, link=webui_link))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -348,7 +377,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
if not validate_attachment_filetype(attachment):
|
||||
media_type = attachment["metadata"]["mediaType"]
|
||||
if not validate_attachment_filetype(media_type):
|
||||
continue
|
||||
attachment_restrictions = attachment.get("restrictions")
|
||||
if not attachment_restrictions:
|
||||
@@ -370,11 +400,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
attachment["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=attachment_perm_sync_data,
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
def _attachment_to_webui_link(
|
||||
confluence_client: OnyxConfluence, attachment: dict[str, Any]
|
||||
) -> str:
|
||||
"""Extracts the webui link to images."""
|
||||
return confluence_client.url + attachment["_links"]["webui"]
|
||||
|
||||
@@ -2,12 +2,18 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
import bs4 # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
from onyx.configs.chat_configs import CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT, CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
from onyx.file_processing.image_summarization import summarize_image_pipeline
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
@@ -17,6 +23,7 @@ from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -179,36 +186,49 @@ def extract_text_from_confluence_html(
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
return attachment["metadata"]["mediaType"] not in [
|
||||
"image/jpeg",
|
||||
"image/png",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
"video/mp4",
|
||||
"video/quicktime",
|
||||
]
|
||||
def validate_attachment_filetype(media_type: str) -> bool:
|
||||
if media_type.startswith("video/") or media_type == "application/gliffy+json":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
page_context: str,
|
||||
llm: LLM,
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
media_type = attachment["metadata"]["mediaType"]
|
||||
|
||||
if media_type.startswith("video/") or media_type == "application/gliffy+json":
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
f"Cannot convert attachment {download_link} with unsupported media type to text: {media_type}."
|
||||
)
|
||||
return None
|
||||
|
||||
if media_type.startswith("image/"):
|
||||
if CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED:
|
||||
try:
|
||||
# get images from page
|
||||
summarization = _summarize_image_attachment(
|
||||
attachment=attachment,
|
||||
page_context=page_context,
|
||||
confluence_client=confluence_client,
|
||||
llm=llm,
|
||||
)
|
||||
return summarization
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to summarize image: {attachment}", exc_info=e)
|
||||
else:
|
||||
logger.warning(
|
||||
"Image summarization is disabled. Skipping image attachment %s",
|
||||
download_link,
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
@@ -222,9 +242,17 @@ def attachment_to_content(
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
if not extracted_text:
|
||||
logger.warning(
|
||||
"Text conversion of attachment %s resulted in an empty string.",
|
||||
download_link,
|
||||
)
|
||||
return None
|
||||
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"Skipping attachment {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
@@ -279,3 +307,41 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
def _attachment_to_download_link(
|
||||
confluence_client: OnyxConfluence, attachment: dict[str, Any]
|
||||
) -> str:
|
||||
"""Extracts the download link to images."""
|
||||
return confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
|
||||
def _summarize_image_attachment(
|
||||
attachment: Dict[str, Any],
|
||||
page_context: str,
|
||||
confluence_client: OnyxConfluence,
|
||||
llm: LLM,
|
||||
) -> str:
|
||||
title = attachment["title"]
|
||||
download_link = _attachment_to_download_link(confluence_client, attachment)
|
||||
|
||||
try:
|
||||
# get image from url
|
||||
image_data = confluence_client.get(
|
||||
download_link, absolute=True, not_json_response=True
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(
|
||||
f"Failed to fetch image for summarization from {download_link}"
|
||||
) from e
|
||||
|
||||
# get image summary
|
||||
# format user prompt: add page title and XML content of page to provide a better summarization
|
||||
user_prompt = CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT.format(
|
||||
title=title, page_title=attachment["title"], confluence_xml=page_context
|
||||
)
|
||||
summary = summarize_image_pipeline(
|
||||
llm, image_data, user_prompt, CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
103
backend/onyx/file_processing/image_summarization.py
Normal file
103
backend/onyx/file_processing/image_summarization.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def summarize_image_pipeline(
|
||||
llm: LLM,
|
||||
image_data: bytes,
|
||||
query: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
) -> str | None:
|
||||
"""Pipeline to generate a summary of an image.
|
||||
Resizes images if it is bigger than 20MB. Encodes image as a base64 string.
|
||||
And finally uses the Default LLM to generate a textual summary of the image."""
|
||||
# resize image if it's bigger than 20MB
|
||||
image_data = _resize_image_if_needed(image_data)
|
||||
|
||||
# encode image (base64)
|
||||
encoded_image = _encode_image(image_data)
|
||||
|
||||
summary = _summarize_image(
|
||||
encoded_image,
|
||||
llm,
|
||||
query,
|
||||
system_prompt,
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
def _summarize_image(
|
||||
encoded_image: str,
|
||||
llm: LLM,
|
||||
query: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
) -> str | None:
|
||||
"""Use default LLM (if it is multimodal) to generate a summary of an image."""
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": query},
|
||||
{"type": "image_url", "image_url": {"url": encoded_image}},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
model_output = message_to_string(llm.invoke(messages))
|
||||
|
||||
return model_output
|
||||
|
||||
except Exception as e:
|
||||
if CONTINUE_ON_CONNECTOR_FAILURE:
|
||||
# Summary of this image will be empty
|
||||
# prevents and infinity retry-loop of the indexing if single summaries fail
|
||||
# for example because content filters got triggert...
|
||||
logger.warning(f"Summarization failed with error: {e}.")
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Summarization failed. Messages: {messages}") from e
|
||||
|
||||
|
||||
def _encode_image(image_data: bytes) -> str:
|
||||
"""Getting the base64 string."""
|
||||
base64_encoded_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
return f"data:image/jpeg;base64,{base64_encoded_data}"
|
||||
|
||||
|
||||
def _resize_image_if_needed(image_data: bytes, max_size_mb: int = 20) -> bytes:
|
||||
"""Resize image if it's larger than the specified max size in MB."""
|
||||
max_size_bytes = max_size_mb * 1024 * 1024
|
||||
|
||||
if len(image_data) > max_size_bytes:
|
||||
with Image.open(BytesIO(image_data)) as img:
|
||||
logger.info("resizing image...")
|
||||
|
||||
# Reduce dimensions for better size reduction
|
||||
img.thumbnail((800, 800), Image.Resampling.LANCZOS)
|
||||
output = BytesIO()
|
||||
|
||||
# Save with lower quality for compression
|
||||
img.save(output, format="JPEG", quality=85)
|
||||
resized_data = output.getvalue()
|
||||
|
||||
return resized_data
|
||||
|
||||
return image_data
|
||||
@@ -225,6 +225,7 @@ class Chunker:
|
||||
|
||||
for section_idx, section in enumerate(document.sections):
|
||||
section_text = clean_text(section.text)
|
||||
|
||||
section_link_text = section.link or ""
|
||||
# If there is no useful content, not even the title, just drop it
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
@@ -287,10 +288,16 @@ class Chunker:
|
||||
current_offset = len(shared_precompare_cleanup(chunk_text))
|
||||
# In the case where the whole section is shorter than a chunk, either add
|
||||
# to chunk or start a new one
|
||||
# If the next section is from a page attachment a new chunk is started
|
||||
# (to ensure that especially image summaries are stored in separate chunks.)
|
||||
prefix = "## Text representation of attachment "
|
||||
next_section_tokens = (
|
||||
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
|
||||
)
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
if (
|
||||
next_section_tokens + current_token_count <= content_token_limit
|
||||
and prefix not in section_text
|
||||
):
|
||||
if chunk_text:
|
||||
chunk_text += SECTION_SEPARATOR
|
||||
chunk_text += section_text
|
||||
|
||||
@@ -372,6 +372,10 @@ class DefaultMultiLLM(LLM):
|
||||
# # Return the lesser of available tokens or configured max
|
||||
# return min(self._max_output_tokens, available_output_tokens)
|
||||
|
||||
def vision_support(self) -> bool | None:
|
||||
model = f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
return litellm.supports_vision(model=model)
|
||||
|
||||
def _completion(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
|
||||
@@ -84,6 +84,9 @@ class LLM(abc.ABC):
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
log_prompt(prompt)
|
||||
|
||||
def vision_support(self) -> bool | None:
|
||||
return None
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
|
||||
40
backend/tests/multimodal_confluence/confluence_connector.py
Normal file
40
backend/tests/multimodal_confluence/confluence_connector.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.connectors.confluence.connector import ConfluenceConnector
|
||||
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED = True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"vision_support_value",
|
||||
[
|
||||
True, # Should not raise an error
|
||||
# False, # Should raise a ValueError
|
||||
# None # Should raise a ValueError
|
||||
],
|
||||
)
|
||||
@patch("danswer.llm.factory.get_default_llms")
|
||||
def test_vision_support(mock_get_default_llms, vision_support_value):
|
||||
"""Test different cases for vision support."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.vision_support.return_value = vision_support_value
|
||||
mock_get_default_llms.return_value = (mock_llm, None)
|
||||
|
||||
if vision_support_value:
|
||||
confluence_connector = ConfluenceConnector(
|
||||
wiki_base="https://example.com",
|
||||
is_cloud=True,
|
||||
)
|
||||
assert confluence_connector is not None # Ensure the connector is instantiated
|
||||
else: # If vision support is False or None
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="The configured default LLM doesn't seem to have vision support for image summarization.",
|
||||
):
|
||||
ConfluenceConnector(
|
||||
wiki_base="https://example.com",
|
||||
is_cloud=True,
|
||||
)
|
||||
255
backend/tests/multimodal_confluence/image_summarization.py
Normal file
255
backend/tests/multimodal_confluence/image_summarization.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from litellm import BadRequestError
|
||||
from openai import RateLimitError
|
||||
from PIL import Image
|
||||
|
||||
from danswer.connectors.confluence.utils import attachment_to_content
|
||||
from danswer.file_processing.image_summarization import _encode_image
|
||||
from danswer.file_processing.image_summarization import _resize_image_if_needed
|
||||
from danswer.file_processing.image_summarization import _summarize_image
|
||||
from danswer.llm.interfaces import LLM
|
||||
|
||||
# Mocking global variables
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED = True
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT = "Summarize this image"
|
||||
|
||||
|
||||
# Mock LLM class for testing
|
||||
class MockLLM(LLM):
|
||||
def invoke(self, messages):
|
||||
# Simulate a response object with a 'content' attribute
|
||||
class Response:
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
|
||||
# Simulate successful invocation
|
||||
return Response("This is a summary of the image.")
|
||||
|
||||
def _invoke_implementation(self):
|
||||
pass
|
||||
|
||||
def _stream_implementation(self):
|
||||
pass
|
||||
|
||||
def log_model_configs(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
pass
|
||||
|
||||
|
||||
# Helper function to create a dummy image
|
||||
def create_image(size: tuple, color: str, format: str) -> bytes:
|
||||
img = Image.new("RGB", size, color)
|
||||
output = BytesIO()
|
||||
img.save(output, format=format)
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
def test_encode_different_image_formats():
|
||||
"""Tests the base64 encoding of different image formats."""
|
||||
formats = ["JPEG", "PNG", "GIF", "EPS"]
|
||||
for format in formats:
|
||||
image_data = create_image((100, 100), "blue", format)
|
||||
|
||||
expected_output = "data:image/jpeg;base64," + base64.b64encode(
|
||||
image_data
|
||||
).decode("utf-8")
|
||||
|
||||
result = _encode_image(image_data)
|
||||
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test_resize_image_above_max_size():
|
||||
"""Test that an image above the max size is resized."""
|
||||
image_data = create_image((2000, 2000), "red", "JPEG") # Large image
|
||||
result = _resize_image_if_needed(image_data, max_size_mb=1)
|
||||
|
||||
# Check if the resized image is below the max size
|
||||
assert len(result) < (1 * 1024 * 1024) # Size should be less than 1 MB
|
||||
|
||||
|
||||
def test_summarization_of_images():
|
||||
"""Test that summarize_image returns a valid summary."""
|
||||
encoded_image = "data:image/jpeg;base64,idFuHHIwEEOHVAA..."
|
||||
query = "What is in this image?"
|
||||
system_prompt = "You are a helpful assistant."
|
||||
llm = MockLLM()
|
||||
|
||||
result = _summarize_image(
|
||||
encoded_image=encoded_image, query=query, system_prompt=system_prompt, llm=llm
|
||||
)
|
||||
assert result == "This is a summary of the image."
|
||||
|
||||
|
||||
# Mock response for RateLimitError
|
||||
class MockResponse:
|
||||
def __init__(self):
|
||||
self.request = "mock_request" # Simulate the request attribute
|
||||
self.status_code = 429
|
||||
self.headers = {"x-request-id": "mock_request_id"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception",
|
||||
[
|
||||
(
|
||||
BadRequestError(
|
||||
"Content policy violation",
|
||||
model="model_name",
|
||||
llm_provider="provider_name",
|
||||
),
|
||||
),
|
||||
(RateLimitError("Retry limit exceeded", response=MockResponse(), body="body"),),
|
||||
],
|
||||
)
|
||||
def test_summarize_image_raises_value_error_on_failure(exception):
|
||||
llm = MockLLM()
|
||||
|
||||
global CONTINUE_ON_CONNECTOR_FAILURE
|
||||
CONTINUE_ON_CONNECTOR_FAILURE = False
|
||||
|
||||
# Set the LLM invoke method to raise the specified exception
|
||||
llm.invoke = MagicMock(side_effect=exception)
|
||||
|
||||
# Use pytest.raises to assert that the exception is raised
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
_summarize_image("encoded_image_string", llm, "test query", "system prompt")
|
||||
|
||||
# Assert that the exception message matches the expected message
|
||||
assert "Summarization failed." in str(excinfo.value)
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(
|
||||
# "exception",
|
||||
# [
|
||||
# (
|
||||
# BadRequestError(
|
||||
# "Content policy violation",
|
||||
# model="model_name",
|
||||
# llm_provider="provider_name",
|
||||
# ),
|
||||
# ),
|
||||
# (
|
||||
# RateLimitError(
|
||||
# "Retry limit exceeded", response=MockResponse(), body="body"
|
||||
# ),
|
||||
# ),
|
||||
# ],
|
||||
# )
|
||||
# def test_summarize_image_return_none_on_failure(exception):
|
||||
# llm = MockLLM()
|
||||
|
||||
# global CONTINUE_ON_CONNECTOR_FAILURE
|
||||
# CONTINUE_ON_CONNECTOR_FAILURE = True
|
||||
|
||||
# # Mock the invoke method to raise a BadRequestError
|
||||
# llm.invoke = MagicMock(side_effect=exception)
|
||||
|
||||
# # Call the summarize_image function
|
||||
# result = _summarize_image("encoded_image_string", llm, "test query", "system prompt")
|
||||
# print(result)
|
||||
# # Assert that the result is None
|
||||
# assert result is None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_page_image():
|
||||
return {
|
||||
"id": 123,
|
||||
"title": "Sample Image",
|
||||
"metadata": {"mediaType": "image/png"},
|
||||
"_links": {"download": "dummy-link.mock"},
|
||||
"history": {
|
||||
"lastUpdated": {"message": "dummy Update"},
|
||||
},
|
||||
"extensions": {"fileSize": 1},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_page_no_image():
|
||||
return {
|
||||
"id": 123,
|
||||
"title": "Test Image",
|
||||
"metadata": {"mediaType": "..."},
|
||||
"_links": {"download": "dummy-link.mock"},
|
||||
"history": {
|
||||
"lastUpdated": {"message": "dummy Update"},
|
||||
},
|
||||
"extensions": {"fileSize": 1},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def confluence_xml():
|
||||
return """
|
||||
<document>
|
||||
<ac:image>
|
||||
<ri:attachment ri:filename="Sample Image" />
|
||||
</ac:image>
|
||||
<ac:structured-macro ac:name="gliffy">
|
||||
<ac:parameter ac:name="imageAttachmentId">123</ac:parameter>
|
||||
</ac:structured-macro>
|
||||
</document>
|
||||
"""
|
||||
|
||||
|
||||
def test_summarize_page_images(sample_page_image, confluence_xml):
|
||||
USER_PROMPT = "Summarize this image"
|
||||
|
||||
# Mock the Confluence client
|
||||
mock_confluence_client = MagicMock()
|
||||
|
||||
# Mock the get method to return valid base64-encoded image data
|
||||
image_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01"
|
||||
mock_confluence_client.get = MagicMock(
|
||||
return_value=image_data
|
||||
) # Return base64-encoded string
|
||||
|
||||
# Mock the _get_embedded_image_attachments function
|
||||
with patch(
|
||||
"danswer.connectors.confluence.utils.attachment_to_content",
|
||||
return_value=[sample_page_image],
|
||||
):
|
||||
# Mock the summarize_image function to return a predefined summary
|
||||
with patch(
|
||||
"danswer.file_processing.image_summarization._summarize_image",
|
||||
return_value="This is a summary of the image.",
|
||||
):
|
||||
# Mock the image summarization pipeline
|
||||
with patch(
|
||||
"danswer.file_processing.image_summarization.summarize_image_pipeline",
|
||||
return_value="This is a summary of the image.",
|
||||
):
|
||||
result = _summarize_image(
|
||||
sample_page_image,
|
||||
MockLLM(),
|
||||
mock_confluence_client,
|
||||
USER_PROMPT,
|
||||
)
|
||||
print(result)
|
||||
|
||||
assert result == "This is a summary of the image."
|
||||
|
||||
|
||||
def test_attachment_to_content_with_no_image(sample_page_no_image, confluence_xml):
|
||||
confluence_client = MagicMock()
|
||||
|
||||
with patch("danswer.connectors.confluence.utils._summarize_image_attachment"):
|
||||
result = attachment_to_content(
|
||||
confluence_client,
|
||||
sample_page_no_image,
|
||||
confluence_xml,
|
||||
MockLLM(),
|
||||
)
|
||||
print(result)
|
||||
|
||||
assert result is None
|
||||
@@ -112,10 +112,18 @@ services:
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
||||
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
|
||||
# Seeding configuration
|
||||
|
||||
# Multimodal-Settings
|
||||
- CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED=${CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED:-}
|
||||
# (optional) Custom Prompt for Image Summarization
|
||||
- CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT=${CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT:-}
|
||||
- CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT=${CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT:-}
|
||||
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
@@ -233,6 +241,12 @@ services:
|
||||
- LOG_INDIVIDUAL_MODEL_TOKENS=${LOG_INDIVIDUAL_MODEL_TOKENS:-}
|
||||
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
|
||||
|
||||
# Multimodal-Settings
|
||||
- CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED=${CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED:-}
|
||||
# (optional) Custom Prompt for Image Summarization
|
||||
- CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT=${CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT:-}
|
||||
- CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT=${CONFLUENCE_IMAGE_SUMMARIZATION_USER_PROMPT:-}
|
||||
|
||||
# Analytics Configs
|
||||
- SENTRY_DSN=${SENTRY_DSN:-}
|
||||
|
||||
|
||||
29
deployment/docker_compose/env.multimodal-confluence.template
Normal file
29
deployment/docker_compose/env.multimodal-confluence.template
Normal file
@@ -0,0 +1,29 @@
|
||||
# This env template shows how to configure Danswer for custom multimodal use of the confluence connector
|
||||
# To use it, copy it to .env in the docker_compose directory (or the equivalent environment settings file for your deployment)
|
||||
|
||||
# Enable generation of summaries of images
|
||||
CONFLUENCE_IMAGE_SUMMARIZATION_ENABLED="True"
|
||||
|
||||
# Optionally use custom prompts to generate summaries of images
|
||||
# We recommend keeping the parameters <title>, <page_title> and especially <confluence_xml> in the user prompt
|
||||
# to provide context of the images and therefore get better / more precise summaries.
|
||||
# Additionally we recommend to use the main language of the provided pages in the prompts.
|
||||
# System Prompt: initial instructions to the model, therefore sets the context/tone/behavior.
|
||||
# default system prompt:
|
||||
# You are an assistant for summarizing images for retrieval.
|
||||
# Summarize the content of the following image and be as precise as possible.
|
||||
# The summary will be embedded and used to retrieve the original image.
|
||||
# Therefore, write a concise summary of the image that is optimized for retrieval.
|
||||
#CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = ""
|
||||
# User Prompt: directly engages with the model to elicit a response. Represents input/query from the end user.
|
||||
# default user prompt:
|
||||
# The image has the file name '{title}' and is embedded on a Confluence page with the title '{page_title}'.
|
||||
# Describe precisely and concisely what the image shows in the context of the page and what it is used for.
|
||||
# The following is the XML source text of the page:
|
||||
# {confluence_xml}
|
||||
#CONFLUENCE_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = ""
|
||||
|
||||
# If the summarization of an image fails due the indexing of the documents is restarted every x (default 30) minutes
|
||||
# if not stopped manually. To avoid this problem (and the resulting possible high costs) we highly recommend
|
||||
# setting CONTINUE_ON_CONNECTOR_FAILURE true, so the summaries for such images stay empty.
|
||||
CONTINUE_ON_CONNECTOR_FAILURE = "True"
|
||||
Reference in New Issue
Block a user