Compare commits

...

25 Commits

Author SHA1 Message Date
edwin-onyx
850262cd39 Merge branch 'main' into edwin/dan-2573 2025-09-19 16:15:46 -07:00
edwin
6afa54d346 . 2025-09-19 13:32:17 -07:00
edwin
c5ef00c739 . 2025-09-19 13:31:01 -07:00
edwin
e8d62bb259 . 2025-09-19 13:27:02 -07:00
edwin
f3b11f57ac . 2025-09-19 13:23:49 -07:00
edwin
4724fd7adf . 2025-09-19 12:00:46 -07:00
edwin
f4f03cc282 . 2025-09-19 11:53:49 -07:00
edwin
3bcff943d1 . 2025-09-19 11:49:53 -07:00
edwin
f430f4892e . 2025-09-19 11:49:34 -07:00
edwin
26cc8f10b9 . 2025-09-19 09:37:37 -07:00
edwin
559ddf31cc . 2025-09-19 09:15:16 -07:00
edwin
9e8acdc9bb . 2025-09-19 09:12:07 -07:00
edwin-onyx
2cd20a1cbc Merge branch 'main' into edwin/dan-2558 2025-09-18 19:54:14 -07:00
edwin
3e0f5f3a21 . 2025-09-18 19:40:17 -07:00
edwin
565ef1e584 . 2025-09-18 19:06:14 -07:00
edwin
579a2936fb . 2025-09-18 18:54:38 -07:00
edwin
526f457545 . 2025-09-18 18:45:24 -07:00
edwin
c1fac15b67 . 2025-09-18 18:40:56 -07:00
edwin
e253b69dc6 Remove Celery memory measurement script
Moving this script to a separate feature branch for better organization.
The script will be available in the feature/celery-memory-measurement branch.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-09-18 18:26:24 -07:00
edwin
a4607a05c3 . 2025-09-18 18:20:15 -07:00
edwin
0a5c7d8edc . 2025-09-18 17:51:39 -07:00
edwin
00c898c0ee . 2025-09-18 17:32:58 -07:00
edwin
dfd1ff533a . 2025-09-18 16:51:16 -07:00
edwin
cf032d8ba6 . 2025-09-18 16:37:15 -07:00
edwin
7c94ce37f9 . 2025-09-18 16:00:51 -07:00
12 changed files with 662 additions and 19 deletions

View File

@@ -37,6 +37,15 @@ repos:
additional_dependencies:
- prettier
- repo: local
hooks:
- id: check-lazy-imports
name: Check lazy imports are not directly imported
entry: python3 backend/scripts/check_lazy_imports.py
language: system
files: ^backend/.*\.py$
pass_filenames: false
# We would like to have a mypy pre-commit hook, but due to the fact that
# pre-commit runs in it's own isolated environment, we would need to install
# and keep in sync all dependencies so mypy has access to the appropriate type

View File

@@ -2,7 +2,6 @@ import string
from collections.abc import Callable
from uuid import UUID
import nltk # type:ignore
from sqlalchemy.orm import Session
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
@@ -60,6 +59,8 @@ def _dedupe_chunks(
def download_nltk_data() -> None:
import nltk
resources = {
"stopwords": "corpora/stopwords",
# "wordnet": "corpora/wordnet", # Not in use
@@ -166,7 +167,6 @@ def doc_index_retrieval(
and query.expanded_queries.keywords_expansions
and query.expanded_queries.semantic_expansions
):
keyword_embeddings_thread = run_in_background(
get_query_embeddings,
query.expanded_queries.keywords_expansions,
@@ -233,13 +233,11 @@ def doc_index_retrieval(
# use all three retrieval methods to retrieve top chunks
if query.search_type == SearchType.SEMANTIC and top_semantic_chunks is not None:
all_top_chunks += top_semantic_chunks
top_chunks = _dedupe_chunks(all_top_chunks)
else:
top_base_chunks_standard_ranking = wait_on_background(
top_base_chunks_standard_ranking_thread
)
@@ -395,8 +393,7 @@ def retrieve_chunks(
if not top_chunks:
logger.warning(
f"Hybrid ({query.search_type.value.capitalize()}) search returned no results "
f"with filters: {query.filters}"
f"Hybrid ({query.search_type.value.capitalize()}) search returned no results with filters: {query.filters}"
)
return []

View File

@@ -2,8 +2,6 @@ import string
from collections.abc import Sequence
from typing import TypeVar
from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from onyx.chat.models import SectionRelevancePiece
@@ -153,6 +151,9 @@ def chunks_or_sections_to_search_docs(
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
from nltk.corpus import stopwords # type: ignore[import-untyped]
from nltk.tokenize import word_tokenize # type: ignore[import-untyped]
try:
# Re-tokenize using the NLTK tokenizer for better matching
query = " ".join(keywords)

View File

@@ -17,9 +17,6 @@ from typing import NamedTuple
from zipfile import BadZipFile
import chardet
from markitdown import FileConversionException
from markitdown import MarkItDown
from markitdown import UnsupportedFormatException
from PIL import Image
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
@@ -330,6 +327,12 @@ def docx_to_text_and_images(
file_name: str = "",
image_callback: Callable[[bytes, str], None] | None = None,
) -> tuple[str, Sequence[tuple[bytes, str]]]:
from markitdown import (
FileConversionException,
UnsupportedFormatException,
MarkItDown,
)
"""
Extract text from a docx.
Return (text_content, list_of_images).
@@ -372,6 +375,12 @@ def docx_to_text_and_images(
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
from markitdown import (
FileConversionException,
UnsupportedFormatException,
MarkItDown,
)
md = MarkItDown(enable_plugins=False)
try:
presentation = md.convert(to_bytesio(file))
@@ -388,6 +397,12 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
from markitdown import (
FileConversionException,
UnsupportedFormatException,
MarkItDown,
)
md = MarkItDown(enable_plugins=False)
try:
workbook = md.convert(to_bytesio(file))

View File

@@ -3,7 +3,6 @@ from collections import defaultdict
from typing import cast
import numpy as np
from nltk import ngrams # type: ignore
from rapidfuzz.distance.DamerauLevenshtein import normalized_similarity
from sqlalchemy import desc
from sqlalchemy import Float
@@ -59,6 +58,8 @@ def _normalize_one_entity(
attributes: dict[str, str],
allowed_docs_temp_view_name: str | None = None,
) -> str | None:
from nltk import ngrams
"""
Matches a single entity to the best matching entity of the same type.
"""
@@ -77,7 +78,6 @@ def _normalize_one_entity(
# step 1: find entities containing the entity_name or something similar
with get_session_with_current_tenant() as db_session:
# get allowed documents
metadata = MetaData()
if allowed_docs_temp_view_name is None:
@@ -257,7 +257,6 @@ def normalize_entities(
for entity, attributes, normalized_entity in zip(
raw_entities, entity_attributes, mapping
):
if normalized_entity is not None:
normalized_entities.append(normalized_entity)
normalized_entities_w_attributes.append(

View File

@@ -15,7 +15,6 @@ import aioboto3 # type: ignore
import httpx
import openai
import requests
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from google.oauth2 import service_account # type: ignore
@@ -25,8 +24,6 @@ from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry
from vertexai.language_models import TextEmbeddingInput # type: ignore
from vertexai.language_models import TextEmbeddingModel # type: ignore
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
@@ -266,6 +263,9 @@ class CloudEmbedding:
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
import vertexai # type: ignore[import-untyped]
from vertexai.language_models import TextEmbeddingModel, TextEmbeddingInput # type: ignore[import-untyped]
if not model:
model = DEFAULT_VERTEX_MODEL
@@ -551,8 +551,7 @@ class EmbeddingModel:
if embed_request.manual_query_prefix or embed_request.manual_passage_prefix:
logger.warning("Prefix provided for cloud model, which is not supported")
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
"Prefix string is not valid for cloud models. Cloud models take an explicit text type instead."
)
if not all(embed_request.texts):

View File

View File

@@ -0,0 +1,176 @@
import logging
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List
from typing import Set
# Configure a console logger for the whole script; violations are reported
# through this logger so pre-commit/CI output is consistently formatted.
logging.basicConfig(
    level=logging.INFO,  # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # Log format
    handlers=[logging.StreamHandler()],  # Output logs to console
)
logger = logging.getLogger(__name__)
@dataclass
class EagerImportResult:
    """Result of checking a file for eager imports."""

    # (line_number, line_content) tuples, one per offending import line
    violation_lines: List[tuple[int, str]]
    # protected module names that actually appeared in a violation
    violated_modules: Set[str]
def find_eager_imports(
    file_path: Path, protected_modules: Set[str]
) -> EagerImportResult:
    """
    Find eager imports of protected modules in a given file.

    Eager imports are top-level (module-level) imports that happen immediately
    when the module is loaded, as opposed to lazy imports that happen inside
    functions only when called.

    Detection is line-based: only single-line imports at indentation 0 are
    matched, so a parenthesized multi-line import is only inspected via its
    first line.

    Args:
        file_path: Path to Python file to check
        protected_modules: Set of module names that should only be imported lazily

    Returns:
        EagerImportResult containing violations list and violated modules set.
        An unreadable file is logged and yields an empty result rather than
        failing the whole run.
    """
    violation_lines: List[tuple[int, str]] = []
    violated_modules: Set[str] = set()
    try:
        content = file_path.read_text(encoding="utf-8")
        lines = content.split("\n")
        for line_num, line in enumerate(lines, 1):
            stripped = line.strip()
            # Skip comments and empty lines
            if not stripped or stripped.startswith("#"):
                continue
            # Only check imports at module level (indentation == 0);
            # indented imports are considered lazy and therefore allowed.
            current_indent = len(line) - len(line.lstrip())
            if current_indent == 0:
                # Check for eager imports of protected modules
                for module in protected_modules:
                    # Pattern 1: "import module" (also "import module.submodule")
                    if re.match(rf"^import\s+{re.escape(module)}(\s|$|\.)", stripped):
                        violation_lines.append((line_num, line))
                        violated_modules.add(module)
                    # Pattern 2: "from module import ..." (also "from module.sub import ...")
                    elif re.match(rf"^from\s+{re.escape(module)}(\s|\.|$)", stripped):
                        violation_lines.append((line_num, line))
                        violated_modules.add(module)
                    # Pattern 3: "from ... import module" (less common but possible)
                    elif re.search(
                        rf"^from\s+[\w.]+\s+import\s+.*\b{re.escape(module)}\b",
                        stripped,
                    ):
                        violation_lines.append((line_num, line))
                        violated_modules.add(module)
    except Exception as e:
        # Fix: report through the module logger instead of a bare print(),
        # consistent with how main() reports violations.
        logger.error(f"Error reading {file_path}: {e}")
    return EagerImportResult(
        violation_lines=violation_lines, violated_modules=violated_modules
    )
def find_python_files(
backend_dir: Path, ignore_directories: Set[str] | None = None
) -> List[Path]:
"""
Find all Python files in the backend directory, excluding test files and ignored directories.
Args:
backend_dir: Path to the backend directory to search
ignore_directories: Set of directory names to ignore (e.g., {"model_server", "tests"})
Returns:
List of Python file paths to check
"""
if ignore_directories is None:
ignore_directories = set()
python_files = []
for file_path in backend_dir.glob("**/*.py"):
# Skip test files (they can contain test imports)
path_parts = file_path.parts
if (
"tests" in path_parts
or file_path.name.startswith("test_")
or file_path.name.endswith("_test.py")
):
continue
# Skip ignored directories (check directory names, not file names)
if any(ignored_dir in path_parts[:-1] for ignored_dir in ignore_directories):
continue
python_files.append(file_path)
return python_files
def main() -> None:
    """Scan backend Python files and fail if any lazy-only module is imported eagerly."""
    backend_dir = Path(__file__).parent.parent  # Go up from scripts/ to backend/

    # Modules that should be imported lazily
    modules_to_lazy_import = {"vertexai", "nltk", "markitdown"}
    ignore_directories = {"model_server"}

    logger.info(
        f"Checking for direct imports of lazy modules: {', '.join(modules_to_lazy_import)}"
    )

    # Every protected module that was eagerly imported anywhere; non-empty
    # once any file has at least one violation.
    all_violated_modules: Set[str] = set()
    for file_path in find_python_files(backend_dir, ignore_directories):
        result = find_eager_imports(file_path, modules_to_lazy_import)
        if not result.violation_lines:
            continue
        all_violated_modules.update(result.violated_modules)
        rel_path = file_path.relative_to(backend_dir)
        logger.error(f"\n❌ Eager import violations found in {rel_path}:")
        for line_num, line in result.violation_lines:
            logger.error(f" Line {line_num}: {line.strip()}")
        # Suggest fix only for violated modules
        if result.violated_modules:
            logger.error(
                f" 💡 You must import {', '.join(sorted(result.violated_modules))} only within functions when needed"
            )

    if all_violated_modules:
        violated_modules_str = ", ".join(sorted(all_violated_modules))
        raise RuntimeError(
            f"Found eager imports of {violated_modules_str}. You must import them only when needed."
        )
    logger.info("✅ All lazy modules are properly imported!")
if __name__ == "__main__":
    try:
        main()
        sys.exit(0)
    except RuntimeError:
        # Violation details were already logged by main(); exit non-zero so
        # the pre-commit hook / CI run fails.
        sys.exit(1)

View File

View File

View File

@@ -0,0 +1,447 @@
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from scripts.check_lazy_imports import EagerImportResult
from scripts.check_lazy_imports import find_eager_imports
from scripts.check_lazy_imports import find_python_files
from scripts.check_lazy_imports import main
def test_find_eager_imports_basic_violations() -> None:
    """Test detection of basic eager import violations.

    Lines 2-5 are module-level imports of protected modules and must be
    flagged; lines 6-7 import unprotected modules and must not be.
    """
    test_content = """
import vertexai
from vertexai import generative_models
import transformers
from transformers import AutoTokenizer
import os # This should not be flagged
from typing import Dict
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(test_content)
        test_path = Path(f.name)
    try:
        protected_modules = {"vertexai", "transformers"}
        result = find_eager_imports(test_path, protected_modules)
        # Should find 4 violations (lines 2, 3, 4, 5)
        assert len(result.violation_lines) == 4
        assert result.violated_modules == {"vertexai", "transformers"}
        # Check specific violations
        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 2 in violation_line_numbers  # import vertexai
        assert 3 in violation_line_numbers  # from vertexai import generative_models
        assert 4 in violation_line_numbers  # import transformers
        assert 5 in violation_line_numbers  # from transformers import AutoTokenizer
        # Lines 6 and 7 should not be flagged
        assert 6 not in violation_line_numbers  # import os
        assert 7 not in violation_line_numbers  # from typing import Dict
    finally:
        test_path.unlink()
def test_find_eager_imports_function_level_allowed() -> None:
    """Test that imports inside functions are allowed (lazy imports).

    Only the two top-level imports (lines 14-15 of the fixture) are
    violations; indented imports inside the function and method are lazy.
    """
    test_content = """import os

def some_function():
    import vertexai
    from transformers import AutoTokenizer
    return vertexai.some_method()
class MyClass:
    def method(self):
        import vertexai
        return vertexai.other_method()


# Top-level imports should be flagged
import vertexai
from transformers import BertModel
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(test_content)
        test_path = Path(f.name)
    try:
        protected_modules = {"vertexai", "transformers"}
        result = find_eager_imports(test_path, protected_modules)
        # Should only find violations for top-level imports (lines 14, 15)
        assert len(result.violation_lines) == 2
        assert result.violated_modules == {"vertexai", "transformers"}
        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 14 in violation_line_numbers  # import vertexai (top-level)
        assert (
            15 in violation_line_numbers
        )  # from transformers import BertModel (top-level)
        # Function-level imports should not be flagged
        assert 4 not in violation_line_numbers  # import vertexai (in function)
        assert (
            5 not in violation_line_numbers
        )  # from transformers import AutoTokenizer (in function)
        assert 9 not in violation_line_numbers  # import vertexai (in method)
    finally:
        test_path.unlink()
def test_find_eager_imports_complex_patterns() -> None:
    """Test detection of various import patterns.

    Submodule imports and ``from pkg import module`` count as violations;
    modules whose names merely contain the protected name do not.
    """
    test_content = """
import vertexai.generative_models # Should be flagged
from some_package import vertexai # Should be flagged
import vertexai_utils # Should not be flagged (different module)
from vertexai_wrapper import something # Should not be flagged
import myvertexai # Should not be flagged
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(test_content)
        test_path = Path(f.name)
    try:
        protected_modules = {"vertexai"}
        result = find_eager_imports(test_path, protected_modules)
        # Should find 2 violations (lines 2, 3)
        assert len(result.violation_lines) == 2
        assert result.violated_modules == {"vertexai"}
        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 2 in violation_line_numbers  # import vertexai.generative_models
        assert 3 in violation_line_numbers  # from some_package import vertexai
        # Lines 4, 5, 6 should not be flagged
        assert 4 not in violation_line_numbers
        assert 5 not in violation_line_numbers
        assert 6 not in violation_line_numbers
    finally:
        test_path.unlink()
def test_find_eager_imports_comments_ignored() -> None:
    """Test that commented imports are ignored."""
    test_content = """
# import vertexai # This should be ignored
import os
# from vertexai import something # This should be ignored
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(test_content)
        test_path = Path(f.name)
    try:
        protected_modules = {"vertexai"}
        result = find_eager_imports(test_path, protected_modules)
        # Should find no violations
        assert len(result.violation_lines) == 0
        assert result.violated_modules == set()
    finally:
        test_path.unlink()
def test_find_eager_imports_no_violations() -> None:
    """Test file with no violations.

    The only use of a protected module is a lazy (function-level) import.
    """
    test_content = """
import os
from typing import Dict, List
from pathlib import Path
def some_function():
    import vertexai
    return vertexai.some_method()
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(test_content)
        test_path = Path(f.name)
    try:
        protected_modules = {"vertexai", "transformers"}
        result = find_eager_imports(test_path, protected_modules)
        # Should find no violations
        assert len(result.violation_lines) == 0
        assert result.violated_modules == set()
    finally:
        test_path.unlink()
def test_find_eager_imports_file_read_error() -> None:
    """Test handling of file read errors.

    An unreadable file must not crash the checker; it is treated as having
    no violations.
    """
    # Create a file path that will cause read errors
    nonexistent_path = Path("/nonexistent/path/test.py")
    protected_modules = {"vertexai"}
    result = find_eager_imports(nonexistent_path, protected_modules)
    # Should return empty result on error
    assert len(result.violation_lines) == 0
    assert result.violated_modules == set()
def test_find_eager_imports_return_type() -> None:
    """Test that function returns correct EagerImportResult type."""
    test_content = """
import vertexai
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(test_content)
        test_path = Path(f.name)
    try:
        protected_modules = {"vertexai"}
        result = find_eager_imports(test_path, protected_modules)
        # Check return type
        assert isinstance(result, EagerImportResult)
        assert hasattr(result, "violation_lines")
        assert hasattr(result, "violated_modules")
        assert len(result.violation_lines) == 1
        assert result.violated_modules == {"vertexai"}
    finally:
        test_path.unlink()
def test_main_function_no_violations(tmp_path: Path) -> None:
    """Test main function with no violations."""
    # Create a temporary backend directory structure
    backend_dir = tmp_path / "backend"
    backend_dir.mkdir()
    # Create a Python file with no violations (avoid "test" in name,
    # otherwise find_python_files would skip it entirely)
    clean_file = backend_dir / "clean_module.py"
    clean_file.write_text(
        """
import os
from pathlib import Path
def use_vertexai():
    import vertexai
    return vertexai.some_method()
"""
    )
    # Mock __file__ so main() resolves backend_dir to our temporary structure
    script_path = backend_dir / "scripts" / "check_lazy_imports.py"
    script_path.parent.mkdir(parents=True)
    with patch("scripts.check_lazy_imports.__file__", str(script_path)):
        # Should not raise an exception
        main()
def test_main_function_with_violations(tmp_path: Path) -> None:
    """Test main function with violations.

    Note: of the two imports below, only ``vertexai`` is in main()'s
    protected set; it alone is enough to trigger the failure.
    """
    # Create a temporary backend directory structure
    backend_dir = tmp_path / "backend"
    backend_dir.mkdir()
    # Create a Python file with violations (avoid "test" in name)
    violation_file = backend_dir / "violation_module.py"
    violation_file.write_text(
        """
import vertexai
from transformers import AutoTokenizer
"""
    )
    # Mock __file__ so main() resolves backend_dir to our temporary structure
    script_path = backend_dir / "scripts" / "check_lazy_imports.py"
    script_path.parent.mkdir(parents=True)
    with patch("scripts.check_lazy_imports.__file__", str(script_path)):
        # Should raise RuntimeError due to violations
        with pytest.raises(
            RuntimeError,
            match="Found eager imports of .+\\. You must import them only when needed",
        ):
            main()
def test_main_function_specific_modules_only() -> None:
    """Test that only specific modules in protected list are flagged."""
    test_content = """
import requests # Should not be flagged
import vertexai # Should be flagged
import transformers # Should be flagged
import numpy # Should not be flagged
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(test_content)
        test_path = Path(f.name)
    try:
        protected_modules = {"vertexai", "transformers"}
        result = find_eager_imports(test_path, protected_modules)
        # Should only flag vertexai and transformers
        assert len(result.violation_lines) == 2
        assert result.violated_modules == {"vertexai", "transformers"}
        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 3 in violation_line_numbers  # import vertexai
        assert 4 in violation_line_numbers  # import transformers
        assert 2 not in violation_line_numbers  # import requests (not protected)
        assert 5 not in violation_line_numbers  # import numpy (not protected)
    finally:
        test_path.unlink()
def test_mixed_violations_and_clean_imports() -> None:
    """Test files with both violations and allowed function-level imports.

    Only the two column-0 imports (fixture lines 3 and 20) are violations.
    """
    test_content = """
# Top-level violation
import vertexai
import os # This is fine, not protected


def process_data():
    # Function-level import is allowed
    import vertexai
    from transformers import AutoTokenizer
    return "processed"

class DataProcessor:
    def __init__(self):
        # Method-level import is allowed
        import transformers
        self.model = transformers.AutoModel()

# Another top-level violation
from transformers import pipeline
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
        f.write(test_content)
        test_path = Path(f.name)
    try:
        protected_modules = {"vertexai", "transformers"}
        result = find_eager_imports(test_path, protected_modules)
        # Should find 2 top-level violations
        assert len(result.violation_lines) == 2
        assert result.violated_modules == {"vertexai", "transformers"}
        violation_line_numbers = [line_num for line_num, _ in result.violation_lines]
        assert 3 in violation_line_numbers  # import vertexai (top-level)
        assert (
            20 in violation_line_numbers
        )  # from transformers import pipeline (top-level)
        # Function and method level imports should not be flagged
        assert 9 not in violation_line_numbers  # import vertexai (in function)
        assert (
            10 not in violation_line_numbers
        )  # from transformers import AutoTokenizer (in function)
        assert 16 not in violation_line_numbers  # import transformers (in method)
    finally:
        test_path.unlink()
def test_find_python_files_basic() -> None:
    """Discovery keeps regular modules and skips test files and tests/ dirs."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        root = Path(tmp_dir)

        # Regular modules at the top level and in a subdirectory.
        (root / "normal.py").write_text("import os")
        (root / "subdir").mkdir()
        (root / "subdir" / "normal.py").write_text("import os")

        # Files that must be filtered out, by name or by directory.
        (root / "test_file.py").write_text("import os")
        (root / "file_test.py").write_text("import os")
        tests_dir = root / "tests"
        tests_dir.mkdir()
        (tests_dir / "test_something.py").write_text("import os")

        # No ignored directories — only the test-file filtering applies.
        names = [p.name for p in find_python_files(root, set())]

        assert "normal.py" in names
        for excluded in ("test_file.py", "file_test.py", "test_something.py"):
            assert excluded not in names
        # normal.py appears twice: once at the root, once under subdir/.
        assert names.count("normal.py") == 2
def test_find_python_files_ignore_directories() -> None:
    """Ignored directory names exclude contents, not similarly named files."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        root = Path(tmp_dir)

        (root / "normal.py").write_text("import os")
        # A filename containing an ignored directory name must stay included.
        (root / "model_server_utils.py").write_text("import os")

        # Files placed inside directories that will be ignored.
        for dirname, filename, body in (
            ("model_server", "model.py", "import transformers"),
            ("ignored", "should_be_ignored.py", "import vertexai"),
        ):
            directory = root / dirname
            directory.mkdir()
            (directory / filename).write_text(body)

        names = [p.name for p in find_python_files(root, {"model_server", "ignored"})]

        assert "normal.py" in names
        # Included: the ignored name appears in the filename, not a directory.
        assert "model_server_utils.py" in names
        # Excluded: both live inside ignored directories.
        assert "model.py" not in names
        assert "should_be_ignored.py" not in names
def test_find_python_files_nested_ignore() -> None:
    """An ignored directory excludes files no matter how deeply it is nested."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        deep = Path(tmp_dir) / "some" / "path" / "model_server" / "nested"
        deep.mkdir(parents=True)
        (deep / "deep_model.py").write_text("import transformers")

        # The only file sits below model_server/, so nothing is returned.
        assert find_python_files(Path(tmp_dir), {"model_server"}) == []