forked from github/onyx
refactor: file processing (#5136)
* file processing refactor
* mypy
* CW comments
* address CW
@@ -67,7 +67,7 @@ def generate_chat_messages_report(
     file_id = file_store.save_file(
         content=temp_file,
         display_name=file_name,
-        file_origin=FileOrigin.OTHER,
+        file_origin=FileOrigin.GENERATED_REPORT,
         file_type="text/csv",
     )

@@ -99,7 +99,7 @@ def generate_user_report(
     file_id = file_store.save_file(
         content=temp_file,
         display_name=file_name,
-        file_origin=FileOrigin.OTHER,
+        file_origin=FileOrigin.GENERATED_REPORT,
         file_type="text/csv",
     )

@@ -1,3 +1,5 @@
+from collections.abc import Generator
+from collections.abc import Iterator
 from datetime import datetime
 from datetime import timezone
 from pathlib import Path
@@ -8,10 +10,12 @@ import httpx

 from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
 from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
+from onyx.connectors.connector_runner import batched_doc_ids
 from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
 )
 from onyx.connectors.interfaces import BaseConnector
+from onyx.connectors.interfaces import CheckpointedConnector
 from onyx.connectors.interfaces import LoadConnector
 from onyx.connectors.interfaces import PollConnector
 from onyx.connectors.interfaces import SlimConnector
@@ -22,12 +26,14 @@ from onyx.utils.logger import setup_logger


 logger = setup_logger()

+PRUNING_CHECKPOINTED_BATCH_SIZE = 32
+

 def document_batch_to_ids(
-    doc_batch: list[Document],
-) -> set[str]:
-    return {doc.id for doc in doc_batch}
+    doc_batch: Iterator[list[Document]],
+) -> Generator[set[str], None, None]:
+    for doc_list in doc_batch:
+        yield {doc.id for doc in doc_list}


 def extract_ids_from_runnable_connector(
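For context on the signature change above: document_batch_to_ids now lazily yields one set of IDs per document batch instead of returning a single set. A minimal, self-contained sketch of how the new shape is consumed (SimpleDoc is a hypothetical stand-in for onyx's Document model, not the real class):

from collections.abc import Generator, Iterator
from dataclasses import dataclass


@dataclass
class SimpleDoc:  # hypothetical stand-in for onyx's Document model
    id: str


def document_batch_to_ids(
    doc_batch: Iterator[list[SimpleDoc]],
) -> Generator[set[str], None, None]:
    for doc_list in doc_batch:
        yield {doc.id for doc in doc_list}


batches = iter([[SimpleDoc("a"), SimpleDoc("b")], [SimpleDoc("c")]])
for id_batch in document_batch_to_ids(batches):
    print(id_batch)  # {'a', 'b'} then {'c'}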
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector
         for metadata_batch in runnable_connector.retrieve_all_slim_documents():
             all_connector_doc_ids.update({doc.id for doc in metadata_batch})

-    doc_batch_generator = None
+    doc_batch_id_generator = None

     if isinstance(runnable_connector, LoadConnector):
-        doc_batch_generator = runnable_connector.load_from_state()
+        doc_batch_id_generator = document_batch_to_ids(
+            runnable_connector.load_from_state()
+        )
     elif isinstance(runnable_connector, PollConnector):
         start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
         end = datetime.now(timezone.utc).timestamp()
-        doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
+        doc_batch_id_generator = document_batch_to_ids(
+            runnable_connector.poll_source(start=start, end=end)
+        )
+    elif isinstance(runnable_connector, CheckpointedConnector):
+        start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
+        end = datetime.now(timezone.utc).timestamp()
+        checkpoint = runnable_connector.build_dummy_checkpoint()
+        checkpoint_generator = runnable_connector.load_from_checkpoint(
+            start=start, end=end, checkpoint=checkpoint
+        )
+        doc_batch_id_generator = batched_doc_ids(
+            checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
+        )
     else:
         raise RuntimeError("Pruning job could not find a valid runnable_connector.")

-    doc_batch_processing_func = document_batch_to_ids
+    # this function is called per batch for rate limiting
+    def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
+        return doc_batch_ids
+
     if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
         doc_batch_processing_func = rate_limit_builder(
             max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
-        )(document_batch_to_ids)
-    for doc_batch in doc_batch_generator:
+        )(lambda x: x)
+    for doc_batch_ids in doc_batch_id_generator:
         if callback:
             if callback.should_stop():
                 raise RuntimeError(
                     "extract_ids_from_runnable_connector: Stop signal detected"
                 )

-        all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
+        all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))

         if callback:
-            callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
+            callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))

     return all_connector_doc_ids
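The loop above now applies rate limiting to an identity function over already-built ID sets rather than wrapping the ID extraction itself. A rough, self-contained sketch of that shape; rate_limit_builder here is a simplified stand-in, not the implementation from onyx.connectors.cross_connector_utils.rate_limit_wrapper:

import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def rate_limit_builder(
    max_calls: int, period: int
) -> Callable[[Callable[[T], T]], Callable[[T], T]]:
    # simplified sliding-window limiter: at most max_calls calls per period seconds
    call_times: list[float] = []

    def decorator(func: Callable[[T], T]) -> Callable[[T], T]:
        def wrapped(arg: T) -> T:
            now = time.monotonic()
            while call_times and now - call_times[0] > period:
                call_times.pop(0)
            if len(call_times) >= max_calls:
                time.sleep(period - (now - call_times[0]))
            call_times.append(time.monotonic())
            return func(arg)

        return wrapped

    return decorator


# mirrors the refactor: throttle an identity function applied to each ID batch
doc_batch_processing_func = rate_limit_builder(max_calls=2, period=1)(lambda x: x)
for id_batch in [{"a"}, {"b"}, {"c"}]:
    print(doc_batch_processing_func(id_batch))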
@@ -25,6 +25,28 @@ TimeRange = tuple[datetime, datetime]

 CT = TypeVar("CT", bound=ConnectorCheckpoint)


+def batched_doc_ids(
+    checkpoint_connector_generator: CheckpointOutput[CT],
+    batch_size: int,
+) -> Generator[set[str], None, None]:
+    batch: set[str] = set()
+    for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
+        checkpoint_connector_generator
+    ):
+        if document is not None:
+            batch.add(document.id)
+        elif (
+            failure and failure.failed_document and failure.failed_document.document_id
+        ):
+            batch.add(failure.failed_document.document_id)
+
+        if len(batch) >= batch_size:
+            yield batch
+            batch = set()
+    if len(batch) > 0:
+        yield batch
+
+
 class CheckpointOutputWrapper(Generic[CT]):
     """
     Wraps a CheckpointOutput generator to give things back in a more digestible format,
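The new batched_doc_ids helper accumulates document IDs (including IDs of failed documents) and flushes them in fixed-size sets, with a final partial batch at the end. A simplified, self-contained sketch of the same batching pattern over plain ID strings, omitting the CheckpointOutputWrapper plumbing:

from collections.abc import Generator, Iterable


def batched_ids(ids: Iterable[str], batch_size: int) -> Generator[set[str], None, None]:
    batch: set[str] = set()
    for doc_id in ids:
        batch.add(doc_id)
        if len(batch) >= batch_size:
            yield batch
            batch = set()
    if batch:
        yield batch


print(list(batched_ids(["a", "b", "c", "d", "e"], batch_size=2)))
# three batches: {'a', 'b'}, {'c', 'd'}, {'e'} (order within each set may vary)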
@@ -32,9 +32,11 @@ def is_valid_image_type(mime_type: str) -> bool:
     Returns:
         True if the MIME type is a valid image type, False otherwise
     """
-    if not mime_type:
-        return False
-    return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
+    return (
+        bool(mime_type)
+        and mime_type.startswith("image/")
+        and mime_type not in EXCLUDED_IMAGE_TYPES
+    )


 def is_supported_by_vision_llm(mime_type: str) -> bool:
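The consolidated return expression above preserves the original behavior, including the empty-string guard. A small illustrative check; EXCLUDED_IMAGE_TYPES containing "image/svg+xml" is an assumption for this sketch, the real set is defined elsewhere in the module:

EXCLUDED_IMAGE_TYPES = {"image/svg+xml"}  # assumed contents, for illustration only


def is_valid_image_type(mime_type: str) -> bool:
    return (
        bool(mime_type)
        and mime_type.startswith("image/")
        and mime_type not in EXCLUDED_IMAGE_TYPES
    )


assert is_valid_image_type("image/png") is True
assert is_valid_image_type("image/svg+xml") is False
assert is_valid_image_type("") is False
assert is_valid_image_type("application/pdf") is False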
@@ -46,7 +46,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
     # Get plaintext file name
     plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)

-    # Use a separate session to avoid committing the caller's transaction
     try:
         file_store = get_default_file_store()
         file_content = BytesIO(plaintext_content.encode("utf-8"))
@@ -867,31 +867,27 @@ def index_doc_batch(
     user_file_id_to_raw_text: dict[int, str] = {}
     for document_id in updatable_ids:
         # Only calculate token counts for documents that have a user file ID
-        if (
-            document_id in doc_id_to_user_file_id
-            and doc_id_to_user_file_id[document_id] is not None
-        ):
-            user_file_id = doc_id_to_user_file_id[document_id]
-            if not user_file_id:
-                continue
-            document_chunks = [
-                chunk
-                for chunk in chunks_with_embeddings
-                if chunk.source_document.id == document_id
-            ]
-            if document_chunks:
-                combined_content = " ".join(
-                    [chunk.content for chunk in document_chunks]
-                )
-                token_count = (
-                    len(llm_tokenizer.encode(combined_content))
-                    if llm_tokenizer
-                    else 0
-                )
-                user_file_id_to_token_count[user_file_id] = token_count
-                user_file_id_to_raw_text[user_file_id] = combined_content
-            else:
-                user_file_id_to_token_count[user_file_id] = None
+        user_file_id = doc_id_to_user_file_id.get(document_id)
+        if user_file_id is None:
+            continue
+
+        document_chunks = [
+            chunk
+            for chunk in chunks_with_embeddings
+            if chunk.source_document.id == document_id
+        ]
+        if document_chunks:
+            combined_content = " ".join(
+                [chunk.content for chunk in document_chunks]
+            )
+            token_count = (
+                len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
+            )
+            user_file_id_to_token_count[user_file_id] = token_count
+            user_file_id_to_raw_text[user_file_id] = combined_content
+        else:
+            user_file_id_to_token_count[user_file_id] = None

     # we're concerned about race conditions where multiple simultaneous indexings might result
     # in one set of metadata overwriting another one in vespa.
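The rewritten loop above resolves the user file ID with .get(), skips documents that have no user file, then concatenates chunk contents and counts tokens once. A minimal sketch of that flow; FakeChunk and WhitespaceTokenizer are hypothetical stand-ins for the indexed chunk objects and llm_tokenizer:

from dataclasses import dataclass


@dataclass
class FakeChunk:  # hypothetical stand-in for an indexed chunk
    source_document_id: str
    content: str


class WhitespaceTokenizer:  # hypothetical stand-in for llm_tokenizer
    def encode(self, text: str) -> list[str]:
        return text.split()


def token_counts_per_user_file(
    updatable_ids: list[str],
    doc_id_to_user_file_id: dict[str, int | None],
    chunks: list[FakeChunk],
    llm_tokenizer: WhitespaceTokenizer | None,
) -> dict[int, int | None]:
    counts: dict[int, int | None] = {}
    for document_id in updatable_ids:
        user_file_id = doc_id_to_user_file_id.get(document_id)
        if user_file_id is None:
            continue
        document_chunks = [c for c in chunks if c.source_document_id == document_id]
        if document_chunks:
            combined = " ".join(c.content for c in document_chunks)
            counts[user_file_id] = (
                len(llm_tokenizer.encode(combined)) if llm_tokenizer else 0
            )
        else:
            counts[user_file_id] = None
    return counts


print(
    token_counts_per_user_file(
        ["doc-1", "doc-2"],
        {"doc-1": 7, "doc-2": None},
        [FakeChunk("doc-1", "hello world"), FakeChunk("doc-1", "again")],
        WhitespaceTokenizer(),
    )
)  # {7: 3}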
@@ -1,3 +1,4 @@
+import io
 import json
 import mimetypes
 import os
@@ -101,8 +102,9 @@ from onyx.db.models import IndexAttempt
 from onyx.db.models import IndexingStatus
 from onyx.db.models import User
 from onyx.db.models import UserGroup__ConnectorCredentialPair
-from onyx.file_processing.extract_file_text import convert_docx_to_txt
+from onyx.file_processing.extract_file_text import extract_file_text
 from onyx.file_store.file_store import get_default_file_store
+from onyx.file_store.models import ChatFileType
 from onyx.key_value_store.interface import KvKeyNotFoundError
 from onyx.server.documents.models import AuthStatus
 from onyx.server.documents.models import AuthUrl
@@ -124,6 +126,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
 from onyx.server.documents.models import ObjectCreationIdResponse
 from onyx.server.documents.models import RunConnectorRequest
 from onyx.server.models import StatusResponse
+from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
 from onyx.utils.logger import setup_logger
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -438,7 +441,9 @@ def is_zip_file(file: UploadFile) -> bool:
     )


-def upload_files(files: list[UploadFile]) -> FileUploadResponse:
+def upload_files(
+    files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
+) -> FileUploadResponse:
     for file in files:
         if not file.filename:
             raise HTTPException(status_code=400, detail="File name cannot be empty")
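With the new signature, upload_files keeps its old behavior for connector-driven callers (file_origin defaults to FileOrigin.CONNECTOR) while other callers can tag uploads explicitly, as the admin route further below does with FileOrigin.OTHER. A small sketch of the default-origin pattern; this FileOrigin enum is a simplified stand-in (member names come from the diff, values are invented):

from enum import Enum


class FileOrigin(str, Enum):  # simplified stand-in for onyx.configs.constants.FileOrigin
    CONNECTOR = "connector"
    OTHER = "other"
    GENERATED_REPORT = "generated_report"


def upload_files_sketch(
    files: list[str], file_origin: FileOrigin = FileOrigin.CONNECTOR
) -> dict:
    # the real upload_files validates, dedupes, and persists via the file store;
    # this sketch only shows the origin being threaded through
    return {"files": files, "origin": file_origin.value}


print(upload_files_sketch(["a.txt"]))                    # origin defaults to "connector"
print(upload_files_sketch(["a.txt"], FileOrigin.OTHER))  # admin API path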
@@ -487,12 +492,17 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
         # For mypy, actual check happens at start of function
         assert file.filename is not None

-        # Special handling for docx files - only store the plaintext version
-        if file.content_type and file.content_type.startswith(
-            "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
-        ):
-            docx_file_id = convert_docx_to_txt(file, file_store)
-            deduped_file_paths.append(docx_file_id)
+        # Special handling for doc files - only store the plaintext version
+        file_type = mime_type_to_chat_file_type(file.content_type)
+        if file_type == ChatFileType.DOC:
+            extracted_text = extract_file_text(file.file, file.filename or "")
+            text_file_id = file_store.save_file(
+                content=io.BytesIO(extracted_text.encode()),
+                display_name=file.filename,
+                file_origin=file_origin,
+                file_type="text/plain",
+            )
+            deduped_file_paths.append(text_file_id)
             deduped_file_names.append(file.filename)
             continue

@@ -520,7 +530,7 @@ def upload_files_api(
     files: list[UploadFile],
     _: User = Depends(current_curator_or_admin_user),
 ) -> FileUploadResponse:
-    return upload_files(files)
+    return upload_files(files, FileOrigin.OTHER)


 @router.get("/admin/connector")
@@ -1,6 +1,5 @@
 import asyncio
 import datetime
-import io
 import json
 import os
 import time
@@ -31,7 +30,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
 from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.configs.chat_configs import HARD_DELETE_CHATS
 from onyx.configs.constants import DocumentSource
-from onyx.configs.constants import FileOrigin
 from onyx.configs.constants import MessageType
 from onyx.configs.constants import MilestoneRecordType
 from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
@@ -63,9 +61,7 @@ from onyx.db.models import User
 from onyx.db.persona import get_persona_by_id
 from onyx.db.user_documents import create_user_files
-from onyx.file_processing.extract_file_text import docx_to_txt_filename
-from onyx.file_processing.extract_file_text import extract_file_text
 from onyx.file_store.file_store import get_default_file_store
 from onyx.file_store.models import ChatFileType
 from onyx.file_store.models import FileDescriptor
 from onyx.llm.exceptions import GenAIDisabledException
 from onyx.llm.factory import get_default_llms
@@ -717,106 +713,65 @@ def upload_files_for_chat(
     ):
         raise HTTPException(
             status_code=400,
-            detail="File size must be less than 20MB",
+            detail="Images must be less than 20MB",
         )

-    file_store = get_default_file_store()
-
-    file_info: list[tuple[str, str | None, ChatFileType]] = []
-    for file in files:
-        file_type = mime_type_to_chat_file_type(file.content_type)
-
-        file_content = file.file.read()  # Read the file content
-
-        # NOTE: Image conversion to JPEG used to be enforced here.
-        # This was removed to:
-        # 1. Preserve original file content for downloads
-        # 2. Maintain transparency in formats like PNG
-        # 3. Ameliorate issue with file conversion
-        file_content_io = io.BytesIO(file_content)
-
-        new_content_type = file.content_type
-
-        # Store the file normally
-        file_id = file_store.save_file(
-            content=file_content_io,
-            display_name=file.filename,
-            file_origin=FileOrigin.CHAT_UPLOAD,
-            file_type=new_content_type or file_type.value,
+    # 5) Create a user file for each uploaded file
+    user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session)
+    for user_file in user_files:
+        # 6) Create connector
+        connector_base = ConnectorBase(
+            name=f"UserFile-{int(time.time())}",
+            source=DocumentSource.FILE,
+            input_type=InputType.LOAD_STATE,
+            connector_specific_config={
+                "file_locations": [user_file.file_id],
+                "file_names": [user_file.name],
+                "zip_metadata": {},
+            },
+            refresh_freq=None,
+            prune_freq=None,
+            indexing_start=None,
+        )
+        connector = create_connector(
+            db_session=db_session,
+            connector_data=connector_base,
         )

-        # 4) If the file is a doc, extract text and store that separately
-        if file_type == ChatFileType.DOC:
-            # Re-wrap bytes in a fresh BytesIO so we start at position 0
-            extracted_text_io = io.BytesIO(file_content)
-            extracted_text = extract_file_text(
-                file=extracted_text_io,  # use the bytes we already read
-                file_name=file.filename or "",
-            )
+        # 7) Create credential
+        credential_info = CredentialBase(
+            credential_json={},
+            admin_public=True,
+            source=DocumentSource.FILE,
+            curator_public=True,
+            groups=[],
+            name=f"UserFileCredential-{int(time.time())}",
+            is_user_file=True,
+        )
+        credential = create_credential(credential_info, user, db_session)

-            text_file_id = file_store.save_file(
-                content=io.BytesIO(extracted_text.encode()),
-                display_name=file.filename,
-                file_origin=FileOrigin.CHAT_UPLOAD,
-                file_type="text/plain",
-            )
-            # Return the text file as the "main" file descriptor for doc types
-            file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT))
-        else:
-            file_info.append((file_id, file.filename, file_type))
-
-        # 5) Create a user file for each uploaded file
-        user_files = create_user_files([file], RECENT_DOCS_FOLDER_ID, user, db_session)
-        for user_file in user_files:
-            # 6) Create connector
-            connector_base = ConnectorBase(
-                name=f"UserFile-{int(time.time())}",
-                source=DocumentSource.FILE,
-                input_type=InputType.LOAD_STATE,
-                connector_specific_config={
-                    "file_locations": [user_file.file_id],
-                    "file_names": [user_file.name],
-                    "zip_metadata": {},
-                },
-                refresh_freq=None,
-                prune_freq=None,
-                indexing_start=None,
-            )
-            connector = create_connector(
-                db_session=db_session,
-                connector_data=connector_base,
-            )
-
-            # 7) Create credential
-            credential_info = CredentialBase(
-                credential_json={},
-                admin_public=True,
-                source=DocumentSource.FILE,
-                curator_public=True,
-                groups=[],
-                name=f"UserFileCredential-{int(time.time())}",
-                is_user_file=True,
-            )
-            credential = create_credential(credential_info, user, db_session)
-
-            # 8) Create connector credential pair
-            cc_pair = add_credential_to_connector(
-                db_session=db_session,
-                user=user,
-                connector_id=connector.id,
-                credential_id=credential.id,
-                cc_pair_name=f"UserFileCCPair-{int(time.time())}",
-                access_type=AccessType.PRIVATE,
-                auto_sync_options=None,
-                groups=[],
-            )
-            user_file.cc_pair_id = cc_pair.data
-            db_session.commit()
+        # 8) Create connector credential pair
+        cc_pair = add_credential_to_connector(
+            db_session=db_session,
+            user=user,
+            connector_id=connector.id,
+            credential_id=credential.id,
+            cc_pair_name=f"UserFileCCPair-{int(time.time())}",
+            access_type=AccessType.PRIVATE,
+            auto_sync_options=None,
+            groups=[],
+        )
+        user_file.cc_pair_id = cc_pair.data
+        db_session.commit()

     return {
         "files": [
-            {"id": file_id, "type": file_type, "name": file_name}
-            for file_id, file_name, file_type in file_info
+            {
+                "id": user_file.file_id,
+                "type": mime_type_to_chat_file_type(user_file.content_type),
+                "name": user_file.name,
+            }
+            for user_file in user_files
         ]
     }
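For reference, upload_files_for_chat now derives its response from the created user files instead of the removed file_info list. A hedged sketch of the resulting shape; FakeUserFile and this mime_type_to_chat_file_type are hypothetical stand-ins for the ORM user file row and the real helper:

from dataclasses import dataclass


@dataclass
class FakeUserFile:  # hypothetical stand-in for the user file row
    file_id: str
    name: str
    content_type: str


def mime_type_to_chat_file_type(mime_type: str) -> str:
    # simplified stand-in: the real helper returns ChatFileType members
    return "image" if mime_type.startswith("image/") else "plain_text"


user_files = [FakeUserFile("abc123", "notes.txt", "text/plain")]
response = {
    "files": [
        {
            "id": user_file.file_id,
            "type": mime_type_to_chat_file_type(user_file.content_type),
            "name": user_file.name,
        }
        for user_file in user_files
    ]
}
print(response)  # {'files': [{'id': 'abc123', 'type': 'plain_text', 'name': 'notes.txt'}]}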