mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-09 08:52:42 +00:00
Compare commits
25 Commits
edge
...
edwin/dan-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
850262cd39 | ||
|
|
6afa54d346 | ||
|
|
c5ef00c739 | ||
|
|
e8d62bb259 | ||
|
|
f3b11f57ac | ||
|
|
4724fd7adf | ||
|
|
f4f03cc282 | ||
|
|
3bcff943d1 | ||
|
|
f430f4892e | ||
|
|
26cc8f10b9 | ||
|
|
559ddf31cc | ||
|
|
9e8acdc9bb | ||
|
|
2cd20a1cbc | ||
|
|
3e0f5f3a21 | ||
|
|
565ef1e584 | ||
|
|
579a2936fb | ||
|
|
526f457545 | ||
|
|
c1fac15b67 | ||
|
|
e253b69dc6 | ||
|
|
a4607a05c3 | ||
|
|
0a5c7d8edc | ||
|
|
00c898c0ee | ||
|
|
dfd1ff533a | ||
|
|
cf032d8ba6 | ||
|
|
7c94ce37f9 |
@@ -37,6 +37,15 @@ repos:
|
||||
additional_dependencies:
|
||||
- prettier
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: check-lazy-imports
|
||||
name: Check lazy imports are not directly imported
|
||||
entry: python3 backend/scripts/check_lazy_imports.py
|
||||
language: system
|
||||
files: ^backend/.*\.py$
|
||||
pass_filenames: false
|
||||
|
||||
# We would like to have a mypy pre-commit hook, but due to the fact that
|
||||
# pre-commit runs in its own isolated environment, we would need to install
|
||||
# and keep in sync all dependencies so mypy has access to the appropriate type
|
||||
|
||||
@@ -2,7 +2,6 @@ import string
|
||||
from collections.abc import Callable
|
||||
from uuid import UUID
|
||||
|
||||
import nltk # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
@@ -60,6 +59,8 @@ def _dedupe_chunks(
|
||||
|
||||
|
||||
def download_nltk_data() -> None:
|
||||
import nltk
|
||||
|
||||
resources = {
|
||||
"stopwords": "corpora/stopwords",
|
||||
# "wordnet": "corpora/wordnet", # Not in use
|
||||
@@ -166,7 +167,6 @@ def doc_index_retrieval(
|
||||
and query.expanded_queries.keywords_expansions
|
||||
and query.expanded_queries.semantic_expansions
|
||||
):
|
||||
|
||||
keyword_embeddings_thread = run_in_background(
|
||||
get_query_embeddings,
|
||||
query.expanded_queries.keywords_expansions,
|
||||
@@ -233,13 +233,11 @@ def doc_index_retrieval(
|
||||
# use all three retrieval methods to retrieve top chunks
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC and top_semantic_chunks is not None:
|
||||
|
||||
all_top_chunks += top_semantic_chunks
|
||||
|
||||
top_chunks = _dedupe_chunks(all_top_chunks)
|
||||
|
||||
else:
|
||||
|
||||
top_base_chunks_standard_ranking = wait_on_background(
|
||||
top_base_chunks_standard_ranking_thread
|
||||
)
|
||||
@@ -395,8 +393,7 @@ def retrieve_chunks(
|
||||
|
||||
if not top_chunks:
|
||||
logger.warning(
|
||||
f"Hybrid ({query.search_type.value.capitalize()}) search returned no results "
|
||||
f"with filters: {query.filters}"
|
||||
f"Hybrid ({query.search_type.value.capitalize()}) search returned no results with filters: {query.filters}"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@ import string
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
@@ -153,6 +151,9 @@ def chunks_or_sections_to_search_docs(
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
from nltk.corpus import stopwords # type: ignore[import-untyped]
|
||||
from nltk.tokenize import word_tokenize # type: ignore[import-untyped]
|
||||
|
||||
try:
|
||||
# Re-tokenize using the NLTK tokenizer for better matching
|
||||
query = " ".join(keywords)
|
||||
|
||||
@@ -17,9 +17,6 @@ from typing import NamedTuple
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
from markitdown import FileConversionException
|
||||
from markitdown import MarkItDown
|
||||
from markitdown import UnsupportedFormatException
|
||||
from PIL import Image
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
@@ -330,6 +327,12 @@ def docx_to_text_and_images(
|
||||
file_name: str = "",
|
||||
image_callback: Callable[[bytes, str], None] | None = None,
|
||||
) -> tuple[str, Sequence[tuple[bytes, str]]]:
|
||||
from markitdown import (
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
MarkItDown,
|
||||
)
|
||||
|
||||
"""
|
||||
Extract text from a docx.
|
||||
Return (text_content, list_of_images).
|
||||
@@ -372,6 +375,12 @@ def docx_to_text_and_images(
|
||||
|
||||
|
||||
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
from markitdown import (
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
MarkItDown,
|
||||
)
|
||||
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
try:
|
||||
presentation = md.convert(to_bytesio(file))
|
||||
@@ -388,6 +397,12 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
from markitdown import (
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
MarkItDown,
|
||||
)
|
||||
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
try:
|
||||
workbook = md.convert(to_bytesio(file))
|
||||
|
||||
@@ -3,7 +3,6 @@ from collections import defaultdict
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
from nltk import ngrams # type: ignore
|
||||
from rapidfuzz.distance.DamerauLevenshtein import normalized_similarity
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import Float
|
||||
@@ -59,6 +58,8 @@ def _normalize_one_entity(
|
||||
attributes: dict[str, str],
|
||||
allowed_docs_temp_view_name: str | None = None,
|
||||
) -> str | None:
|
||||
from nltk import ngrams
|
||||
|
||||
"""
|
||||
Matches a single entity to the best matching entity of the same type.
|
||||
"""
|
||||
@@ -77,7 +78,6 @@ def _normalize_one_entity(
|
||||
|
||||
# step 1: find entities containing the entity_name or something similar
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# get allowed documents
|
||||
metadata = MetaData()
|
||||
if allowed_docs_temp_view_name is None:
|
||||
@@ -257,7 +257,6 @@ def normalize_entities(
|
||||
for entity, attributes, normalized_entity in zip(
|
||||
raw_entities, entity_attributes, mapping
|
||||
):
|
||||
|
||||
if normalized_entity is not None:
|
||||
normalized_entities.append(normalized_entity)
|
||||
normalized_entities_w_attributes.append(
|
||||
|
||||
@@ -15,7 +15,6 @@ import aioboto3 # type: ignore
|
||||
import httpx
|
||||
import openai
|
||||
import requests
|
||||
import vertexai # type: ignore
|
||||
import voyageai # type: ignore
|
||||
from cohere import AsyncClient as CohereAsyncClient
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
@@ -25,8 +24,6 @@ from requests import JSONDecodeError
|
||||
from requests import RequestException
|
||||
from requests import Response
|
||||
from retry import retry
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingModel # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
|
||||
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
|
||||
@@ -266,6 +263,9 @@ class CloudEmbedding:
|
||||
async def _embed_vertex(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
import vertexai # type: ignore[import-untyped]
|
||||
from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput # type: ignore[import-untyped]
|
||||
|
||||
if not model:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
|
||||
@@ -551,8 +551,7 @@ class EmbeddingModel:
|
||||
if embed_request.manual_query_prefix or embed_request.manual_passage_prefix:
|
||||
logger.warning("Prefix provided for cloud model, which is not supported")
|
||||
raise ValueError(
|
||||
"Prefix string is not valid for cloud models. "
|
||||
"Cloud models take an explicit text type instead."
|
||||
"Prefix string is not valid for cloud models. Cloud models take an explicit text type instead."
|
||||
)
|
||||
|
||||
if not all(embed_request.texts):
|
||||
|
||||
0
backend/scripts/__init__.py
Normal file
0
backend/scripts/__init__.py
Normal file
176
backend/scripts/check_lazy_imports.py
Normal file
176
backend/scripts/check_lazy_imports.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Set
|
||||
|
||||
# Console logging configuration shared by the whole script.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

logging.basicConfig(
    level=logging.INFO,  # DEBUG, INFO, WARNING, ERROR, CRITICAL
    format=_LOG_FORMAT,
    handlers=[logging.StreamHandler()],  # emit to the console
)

logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class EagerImportResult:
    """Result of checking a file for eager imports."""

    # (line_number, line_content) tuples for each offending top-level import
    violation_lines: List[tuple[int, str]]
    # protected module names that were actually imported eagerly
    violated_modules: Set[str]


def find_eager_imports(
    file_path: Path, protected_modules: Set[str]
) -> EagerImportResult:
    """
    Find eager imports of protected modules in a given file.

    Eager imports are top-level (module-level) imports that happen immediately
    when the module is loaded, as opposed to lazy imports that happen inside
    functions only when called.

    Args:
        file_path: Path to Python file to check
        protected_modules: Set of module names that should only be imported lazily

    Returns:
        EagerImportResult containing violations list and violated modules set
    """
    violation_lines: List[tuple[int, str]] = []
    violated_modules: Set[str] = set()

    # Compile one pattern trio per protected module up front so the inner
    # per-line loop does not pay a regex-cache lookup for every (line, module)
    # pair.  The three patterns are:
    #   1. import module[.sub]
    #   2. from module import ...   (also "from module.sub import ...")
    #   3. from pkg import module   (less common but possible)
    compiled = [
        (
            module,
            re.compile(rf"^import\s+{re.escape(module)}(\s|$|\.)"),
            re.compile(rf"^from\s+{re.escape(module)}(\s|\.|$)"),
            re.compile(rf"^from\s+[\w.]+\s+import\s+.*\b{re.escape(module)}\b"),
        )
        for module in protected_modules
    ]

    try:
        content = file_path.read_text(encoding="utf-8")

        for line_num, line in enumerate(content.split("\n"), 1):
            stripped = line.strip()

            # Skip comments and empty lines
            if not stripped or stripped.startswith("#"):
                continue

            # Only check imports at module level (indentation == 0);
            # indented imports live inside a function/method and are lazy.
            if len(line) - len(line.lstrip()) != 0:
                continue

            for module, import_pat, from_pat, reimport_pat in compiled:
                if (
                    import_pat.match(stripped)
                    or from_pat.match(stripped)
                    or reimport_pat.search(stripped)
                ):
                    violation_lines.append((line_num, line))
                    violated_modules.add(module)

    except Exception as e:
        # Best-effort: an unreadable file is reported and treated the same as
        # "no violations".  Use the module's configured logging (was a bare
        # print(), inconsistent with the rest of the script's output).
        logging.getLogger(__name__).error(f"Error reading {file_path}: {e}")

    return EagerImportResult(
        violation_lines=violation_lines, violated_modules=violated_modules
    )
|
||||
|
||||
|
||||
def find_python_files(
|
||||
backend_dir: Path, ignore_directories: Set[str] | None = None
|
||||
) -> List[Path]:
|
||||
"""
|
||||
Find all Python files in the backend directory, excluding test files and ignored directories.
|
||||
|
||||
Args:
|
||||
backend_dir: Path to the backend directory to search
|
||||
ignore_directories: Set of directory names to ignore (e.g., {"model_server", "tests"})
|
||||
|
||||
Returns:
|
||||
List of Python file paths to check
|
||||
"""
|
||||
if ignore_directories is None:
|
||||
ignore_directories = set()
|
||||
|
||||
python_files = []
|
||||
for file_path in backend_dir.glob("**/*.py"):
|
||||
# Skip test files (they can contain test imports)
|
||||
path_parts = file_path.parts
|
||||
if (
|
||||
"tests" in path_parts
|
||||
or file_path.name.startswith("test_")
|
||||
or file_path.name.endswith("_test.py")
|
||||
):
|
||||
continue
|
||||
|
||||
# Skip ignored directories (check directory names, not file names)
|
||||
if any(ignored_dir in path_parts[:-1] for ignored_dir in ignore_directories):
|
||||
continue
|
||||
|
||||
python_files.append(file_path)
|
||||
|
||||
return python_files
|
||||
|
||||
|
||||
def main() -> None:
    """Scan backend Python files and raise RuntimeError if any protected
    module is imported eagerly at module level."""
    # scripts/ sits directly under backend/, so go up one level.
    backend_dir = Path(__file__).parent.parent

    # Modules that should be imported lazily
    modules_to_lazy_import = {"vertexai", "nltk", "markitdown"}
    ignore_directories = {"model_server"}

    logger.info(
        f"Checking for direct imports of lazy modules: {', '.join(modules_to_lazy_import)}"
    )

    all_violated_modules: set = set()
    violations_found = False

    # Check each Python file in the backend tree.
    for file_path in find_python_files(backend_dir, ignore_directories):
        result = find_eager_imports(file_path, modules_to_lazy_import)
        if not result.violation_lines:
            continue

        violations_found = True
        all_violated_modules.update(result.violated_modules)
        logger.error(
            f"\n❌ Eager import violations found in {file_path.relative_to(backend_dir)}:"
        )
        for line_num, line in result.violation_lines:
            logger.error(f"  Line {line_num}: {line.strip()}")
        # Suggest a fix only for the modules this file actually violated.
        if result.violated_modules:
            logger.error(
                f"  💡 You must import {', '.join(sorted(result.violated_modules))} only within functions when needed"
            )

    if not violations_found:
        logger.info("✅ All lazy modules are properly imported!")
        return

    violated_modules_str = ", ".join(sorted(all_violated_modules))
    raise RuntimeError(
        f"Found eager imports of {violated_modules_str}. You must import them only when needed."
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Exit 1 when violations are found (main raises RuntimeError), 0 otherwise.
    exit_code = 0
    try:
        main()
    except RuntimeError:
        exit_code = 1
    sys.exit(exit_code)
|
||||
0
backend/tests/unit/__init__.py
Normal file
0
backend/tests/unit/__init__.py
Normal file
0
backend/tests/unit/onyx/lazy_handling/__init__.py
Normal file
0
backend/tests/unit/onyx/lazy_handling/__init__.py
Normal file
0
backend/tests/unit/scripts/__init__.py
Normal file
0
backend/tests/unit/scripts/__init__.py
Normal file
447
backend/tests/unit/scripts/test_check_lazy_imports.py
Normal file
447
backend/tests/unit/scripts/test_check_lazy_imports.py
Normal file
@@ -0,0 +1,447 @@
|
||||
import tempfile
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch

import pytest

from scripts.check_lazy_imports import EagerImportResult
from scripts.check_lazy_imports import find_eager_imports
from scripts.check_lazy_imports import find_python_files
from scripts.check_lazy_imports import main
|
||||
|
||||
|
||||
@contextmanager
def _temp_python_file(content: str) -> Iterator[Path]:
    """Write *content* to a throwaway .py file, yield its path, always unlink it.

    Factors out the NamedTemporaryFile + try/finally boilerplate that was
    duplicated across most tests below.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(content)
        path = Path(f.name)
    try:
        yield path
    finally:
        path.unlink()


def test_find_eager_imports_basic_violations() -> None:
    """Test detection of basic eager import violations."""
    test_content = """
import vertexai
from vertexai import generative_models
import transformers
from transformers import AutoTokenizer
import os # This should not be flagged
from typing import Dict
"""

    with _temp_python_file(test_content) as test_path:
        result = find_eager_imports(test_path, {"vertexai", "transformers"})

        # Should find 4 violations (lines 2, 3, 4, 5)
        assert len(result.violation_lines) == 4
        assert result.violated_modules == {"vertexai", "transformers"}

        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 2 in violation_line_numbers  # import vertexai
        assert 3 in violation_line_numbers  # from vertexai import generative_models
        assert 4 in violation_line_numbers  # import transformers
        assert 5 in violation_line_numbers  # from transformers import AutoTokenizer

        # Lines 6 and 7 should not be flagged
        assert 6 not in violation_line_numbers  # import os
        assert 7 not in violation_line_numbers  # from typing import Dict


def test_find_eager_imports_function_level_allowed() -> None:
    """Test that imports inside functions are allowed (lazy imports)."""
    test_content = """import os

def some_function():
    import vertexai
    from transformers import AutoTokenizer
    return vertexai.some_method()

class MyClass:
    def method(self):
        import vertexai
        return vertexai.other_method()

# Top-level imports should be flagged
import vertexai
from transformers import BertModel
"""

    with _temp_python_file(test_content) as test_path:
        result = find_eager_imports(test_path, {"vertexai", "transformers"})

        # Should only find violations for top-level imports (lines 14, 15)
        assert len(result.violation_lines) == 2
        assert result.violated_modules == {"vertexai", "transformers"}

        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 14 in violation_line_numbers  # import vertexai (top-level)
        assert 15 in violation_line_numbers  # from transformers import BertModel

        # Function-level imports should not be flagged
        assert 4 not in violation_line_numbers  # import vertexai (in function)
        assert 5 not in violation_line_numbers  # from transformers import ... (in function)
        assert 10 not in violation_line_numbers  # import vertexai (in method)


def test_find_eager_imports_complex_patterns() -> None:
    """Test detection of various import patterns."""
    test_content = """
import vertexai.generative_models # Should be flagged
from some_package import vertexai # Should be flagged
import vertexai_utils # Should not be flagged (different module)
from vertexai_wrapper import something # Should not be flagged
import myvertexai # Should not be flagged
"""

    with _temp_python_file(test_content) as test_path:
        result = find_eager_imports(test_path, {"vertexai"})

        # Should find 2 violations (lines 2, 3)
        assert len(result.violation_lines) == 2
        assert result.violated_modules == {"vertexai"}

        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 2 in violation_line_numbers  # import vertexai.generative_models
        assert 3 in violation_line_numbers  # from some_package import vertexai

        # Lines 4, 5, 6 should not be flagged
        assert 4 not in violation_line_numbers
        assert 5 not in violation_line_numbers
        assert 6 not in violation_line_numbers


def test_find_eager_imports_comments_ignored() -> None:
    """Test that commented imports are ignored."""
    test_content = """
# import vertexai # This should be ignored
import os
# from vertexai import something # This should be ignored
"""

    with _temp_python_file(test_content) as test_path:
        result = find_eager_imports(test_path, {"vertexai"})

        # Should find no violations
        assert len(result.violation_lines) == 0
        assert result.violated_modules == set()


def test_find_eager_imports_no_violations() -> None:
    """Test file with no violations."""
    test_content = """
import os
from typing import Dict, List
from pathlib import Path

def some_function():
    import vertexai
    return vertexai.some_method()
"""

    with _temp_python_file(test_content) as test_path:
        result = find_eager_imports(test_path, {"vertexai", "transformers"})

        # Should find no violations
        assert len(result.violation_lines) == 0
        assert result.violated_modules == set()


def test_find_eager_imports_file_read_error() -> None:
    """Test handling of file read errors."""
    # Create a file path that will cause read errors
    nonexistent_path = Path("/nonexistent/path/test.py")

    result = find_eager_imports(nonexistent_path, {"vertexai"})

    # Should return empty result on error
    assert len(result.violation_lines) == 0
    assert result.violated_modules == set()


def test_find_eager_imports_return_type() -> None:
    """Test that function returns correct EagerImportResult type."""
    test_content = """
import vertexai
"""

    with _temp_python_file(test_content) as test_path:
        result = find_eager_imports(test_path, {"vertexai"})

        # Check return type
        assert isinstance(result, EagerImportResult)
        assert hasattr(result, "violation_lines")
        assert hasattr(result, "violated_modules")
        assert len(result.violation_lines) == 1
        assert result.violated_modules == {"vertexai"}


def test_main_function_no_violations(tmp_path: Path) -> None:
    """Test main function with no violations."""
    # Create a temporary backend directory structure
    backend_dir = tmp_path / "backend"
    backend_dir.mkdir()

    # Create a Python file with no violations (avoid "test" in name)
    clean_file = backend_dir / "clean_module.py"
    clean_file.write_text(
        """
import os
from pathlib import Path

def use_vertexai():
    import vertexai
    return vertexai.some_method()
"""
    )

    # Mock __file__ so main() resolves backend_dir to our temporary structure
    script_path = backend_dir / "scripts" / "check_lazy_imports.py"
    script_path.parent.mkdir(parents=True)

    with patch("scripts.check_lazy_imports.__file__", str(script_path)):
        # Should not raise an exception
        main()


def test_main_function_with_violations(tmp_path: Path) -> None:
    """Test main function with violations."""
    # Create a temporary backend directory structure
    backend_dir = tmp_path / "backend"
    backend_dir.mkdir()

    # Create a Python file with violations (avoid "test" in name)
    violation_file = backend_dir / "violation_module.py"
    violation_file.write_text(
        """
import vertexai
from transformers import AutoTokenizer
"""
    )

    # Mock __file__ so main() resolves backend_dir to our temporary structure
    script_path = backend_dir / "scripts" / "check_lazy_imports.py"
    script_path.parent.mkdir(parents=True)

    with patch("scripts.check_lazy_imports.__file__", str(script_path)):
        # Should raise RuntimeError due to violations
        with pytest.raises(
            RuntimeError,
            match="Found eager imports of .+\\. You must import them only when needed",
        ):
            main()


def test_main_function_specific_modules_only() -> None:
    """Test that only specific modules in protected list are flagged."""
    test_content = """
import requests # Should not be flagged
import vertexai # Should be flagged
import transformers # Should be flagged
import numpy # Should not be flagged
"""

    with _temp_python_file(test_content) as test_path:
        result = find_eager_imports(test_path, {"vertexai", "transformers"})

        # Should only flag vertexai and transformers
        assert len(result.violation_lines) == 2
        assert result.violated_modules == {"vertexai", "transformers"}

        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 3 in violation_line_numbers  # import vertexai
        assert 4 in violation_line_numbers  # import transformers
        assert 2 not in violation_line_numbers  # import requests (not protected)
        assert 5 not in violation_line_numbers  # import numpy (not protected)


def test_mixed_violations_and_clean_imports() -> None:
    """Test files with both violations and allowed function-level imports."""
    test_content = """
# Top-level violation
import vertexai

import os # This is fine, not protected

def process_data():
    # Function-level import is allowed
    import vertexai
    from transformers import AutoTokenizer
    return "processed"

class DataProcessor:
    def __init__(self):
        # Method-level import is allowed
        import transformers
        self.model = transformers.AutoModel()

# Another top-level violation
from transformers import pipeline
"""

    with _temp_python_file(test_content) as test_path:
        result = find_eager_imports(test_path, {"vertexai", "transformers"})

        # Should find 2 top-level violations
        assert len(result.violation_lines) == 2
        assert result.violated_modules == {"vertexai", "transformers"}

        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 3 in violation_line_numbers  # import vertexai (top-level)
        assert 20 in violation_line_numbers  # from transformers import pipeline

        # Function and method level imports should not be flagged
        assert 9 not in violation_line_numbers  # import vertexai (in function)
        assert 10 not in violation_line_numbers  # from transformers import ... (in function)
        assert 16 not in violation_line_numbers  # import transformers (in method)


def test_find_python_files_basic() -> None:
    """Test finding Python files with basic filtering."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        backend_dir = Path(tmp_dir)

        # Create various files; test-named files must be excluded
        (backend_dir / "normal.py").write_text("import os")
        (backend_dir / "test_file.py").write_text("import os")  # Should be excluded
        (backend_dir / "file_test.py").write_text("import os")  # Should be excluded

        # Create subdirectories
        (backend_dir / "subdir").mkdir()
        (backend_dir / "subdir" / "normal.py").write_text("import os")

        tests_dir = backend_dir / "tests"
        tests_dir.mkdir()
        (tests_dir / "test_something.py").write_text("import os")  # Should be excluded

        # Test with no ignore directories
        files = find_python_files(backend_dir, set())
        file_names = [f.name for f in files]

        assert "normal.py" in file_names
        assert "test_file.py" not in file_names
        assert "file_test.py" not in file_names
        assert "test_something.py" not in file_names
        # One normal.py in the root and one in subdir/
        assert len([f for f in files if f.name == "normal.py"]) == 2


def test_find_python_files_ignore_directories() -> None:
    """Test finding Python files with ignored directories."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        backend_dir = Path(tmp_dir)

        # Create files in various directories
        (backend_dir / "normal.py").write_text("import os")

        model_server_dir = backend_dir / "model_server"
        model_server_dir.mkdir()
        (model_server_dir / "model.py").write_text(
            "import transformers"
        )  # Should be excluded

        ignored_dir = backend_dir / "ignored"
        ignored_dir.mkdir()
        (ignored_dir / "should_be_ignored.py").write_text(
            "import vertexai"
        )  # Should be excluded

        # A file with an ignored directory name in its filename is still included
        (backend_dir / "model_server_utils.py").write_text("import os")

        # Test with ignore directories
        files = find_python_files(backend_dir, {"model_server", "ignored"})
        file_names = [f.name for f in files]

        assert "normal.py" in file_names
        assert "model.py" not in file_names  # in model_server directory
        assert "should_be_ignored.py" not in file_names  # in ignored directory
        assert "model_server_utils.py" in file_names  # filename match only


def test_find_python_files_nested_ignore() -> None:
    """Test that ignored directories work with nested paths."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        backend_dir = Path(tmp_dir)

        # Create nested structure
        nested_path = backend_dir / "some" / "path" / "model_server" / "nested"
        nested_path.mkdir(parents=True)
        (nested_path / "deep_model.py").write_text("import transformers")

        files = find_python_files(backend_dir, {"model_server"})

        # Should exclude the deeply nested file
        assert len(files) == 0
||||
Reference in New Issue
Block a user