1
0
forked from github/onyx

refactor: file processing (#5136)

* file processing refactor

* mypy

* CW comments

* address CW
This commit is contained in:
Evan Lohn
2025-08-07 17:34:35 -07:00
committed by GitHub
parent bd4bd00cef
commit 297720c132
8 changed files with 154 additions and 147 deletions

View File

@@ -67,7 +67,7 @@ def generate_chat_messages_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)
@@ -99,7 +99,7 @@ def generate_user_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)

View File

@@ -1,3 +1,5 @@
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
@@ -8,10 +10,12 @@ import httpx
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -22,12 +26,14 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: list[Document],
) -> set[str]:
return {doc.id for doc in doc_batch}
doc_batch: Iterator[list[Document]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
def extract_ids_from_runnable_connector(
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector(
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_generator = None
doc_batch_id_generator = None
if isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.poll_source(start=start, end=end)
)
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
doc_batch_processing_func = document_batch_to_ids
# this function is called per batch for rate limiting
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
)(lambda x: x)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
return all_connector_doc_ids

View File

@@ -25,6 +25,28 @@ TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def batched_doc_ids(
checkpoint_connector_generator: CheckpointOutput[CT],
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
if len(batch) >= batch_size:
yield batch
batch = set()
if len(batch) > 0:
yield batch
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format,

View File

@@ -32,9 +32,11 @@ def is_valid_image_type(mime_type: str) -> bool:
Returns:
True if the MIME type is a valid image type, False otherwise
"""
if not mime_type:
return False
return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
return (
bool(mime_type)
and mime_type.startswith("image/")
and mime_type not in EXCLUDED_IMAGE_TYPES
)
def is_supported_by_vision_llm(mime_type: str) -> bool:

View File

@@ -46,7 +46,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
# Use a separate session to avoid committing the caller's transaction
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))

View File

@@ -867,31 +867,27 @@ def index_doc_batch(
user_file_id_to_raw_text: dict[int, str] = {}
for document_id in updatable_ids:
# Only calculate token counts for documents that have a user file ID
if (
document_id in doc_id_to_user_file_id
and doc_id_to_user_file_id[document_id] is not None
):
user_file_id = doc_id_to_user_file_id[document_id]
if not user_file_id:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content))
if llm_tokenizer
else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
user_file_id = doc_id_to_user_file_id.get(document_id)
if user_file_id is None:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.

View File

@@ -1,3 +1,4 @@
import io
import json
import mimetypes
import os
@@ -101,8 +102,9 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.file_processing.extract_file_text import convert_docx_to_txt
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
@@ -124,6 +126,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -438,7 +441,9 @@ def is_zip_file(file: UploadFile) -> bool:
)
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
@@ -487,12 +492,17 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
# For mypy, actual check happens at start of function
assert file.filename is not None
# Special handling for docx files - only store the plaintext version
if file.content_type and file.content_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
docx_file_id = convert_docx_to_txt(file, file_store)
deduped_file_paths.append(docx_file_id)
# Special handling for doc files - only store the plaintext version
file_type = mime_type_to_chat_file_type(file.content_type)
if file_type == ChatFileType.DOC:
extracted_text = extract_file_text(file.file, file.filename or "")
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=file_origin,
file_type="text/plain",
)
deduped_file_paths.append(text_file_id)
deduped_file_names.append(file.filename)
continue
@@ -520,7 +530,7 @@ def upload_files_api(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
) -> FileUploadResponse:
return upload_files(files)
return upload_files(files, FileOrigin.OTHER)
@router.get("/admin/connector")

View File

@@ -1,6 +1,5 @@
import asyncio
import datetime
import io
import json
import os
import time
@@ -31,7 +30,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
@@ -63,9 +61,7 @@ from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.user_documents import create_user_files
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
@@ -717,106 +713,65 @@ def upload_files_for_chat(
):
raise HTTPException(
status_code=400,
detail="File size must be less than 20MB",
detail="Images must be less than 20MB",
)
file_store = get_default_file_store()
file_info: list[tuple[str, str | None, ChatFileType]] = []
for file in files:
file_type = mime_type_to_chat_file_type(file.content_type)
file_content = file.file.read() # Read the file content
# NOTE: Image conversion to JPEG used to be enforced here.
# This was removed to:
# 1. Preserve original file content for downloads
# 2. Maintain transparency in formats like PNG
# 3. Ameliorate issue with file conversion
file_content_io = io.BytesIO(file_content)
new_content_type = file.content_type
# Store the file normally
file_id = file_store.save_file(
content=file_content_io,
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=new_content_type or file_type.value,
# 5) Create a user file for each uploaded file
user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session)
for user_file in user_files:
# 6) Create connector
connector_base = ConnectorBase(
name=f"UserFile-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)
# 4) If the file is a doc, extract text and store that separately
if file_type == ChatFileType.DOC:
# Re-wrap bytes in a fresh BytesIO so we start at position 0
extracted_text_io = io.BytesIO(file_content)
extracted_text = extract_file_text(
file=extracted_text_io, # use the bytes we already read
file_name=file.filename or "",
)
# 7) Create credential
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type="text/plain",
)
# Return the text file as the "main" file descriptor for doc types
file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT))
else:
file_info.append((file_id, file.filename, file_type))
# 5) Create a user file for each uploaded file
user_files = create_user_files([file], RECENT_DOCS_FOLDER_ID, user, db_session)
for user_file in user_files:
# 6) Create connector
connector_base = ConnectorBase(
name=f"UserFile-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)
# 7) Create credential
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
# 8) Create connector credential pair
cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
# 8) Create connector credential pair
cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
return {
"files": [
{"id": file_id, "type": file_type, "name": file_name}
for file_id, file_name, file_type in file_info
{
"id": user_file.file_id,
"type": mime_type_to_chat_file_type(user_file.content_type),
"name": user_file.name,
}
for user_file in user_files
]
}