Add utility for Reload Transformers imports cache for development workflow huggingface#35508 (huggingface#35858)

* Reload transformers fix from cache

* add imports

* add test fn for clearing import cache

* ruff fix to core import logic

* ruff fix to test file

* fixup for imports

* fixup for test

* lru restore

* test check

* fix style changes

* added documentation for use case

* fixing

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
2 people authored and sbucaille committed Feb 14, 2025
1 parent a2176d8 commit 99cca44
Showing 3 changed files with 79 additions and 1 deletion.
32 changes: 31 additions & 1 deletion docs/source/en/how_to_hack_models.md
@@ -24,7 +24,37 @@ You'll learn how to:
- Modify a model's architecture by changing its attention mechanism.
- Apply techniques like Low-Rank Adaptation (LoRA) to specific model components.

We encourage you to contribute your own hacks and share them here with the community!

## Efficient Development Workflow

When modifying model code, you often need to test your changes without restarting your Python session. The `clear_import_cache()` utility helps with this workflow, especially during model development and contribution, when you frequently need to test and compare model outputs:

```python
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-uncased")

# Make modifications to the transformers code...

# Clear the cache to reload the modified code
from transformers.utils.import_utils import clear_import_cache
clear_import_cache()

# Reimport to get the changes
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-uncased") # Will use updated code
```

This is particularly useful when:
- Iteratively modifying model architectures
- Debugging model implementations
- Testing changes during model development
- Comparing outputs between original and modified versions
- Working on model contributions

The `clear_import_cache()` function removes all cached Transformers modules from `sys.modules`, so the next import re-executes your modified code. Note that objects created before the cache was cleared, such as the first `model` above, still reference the old code; re-import and re-instantiate anything that should pick up your changes. This enables rapid development cycles without constantly restarting your environment.
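To see the mechanism at work, here is a minimal sketch that counts the cached modules before and after clearing (exact module names and counts will vary with your environment):

```python
import sys

from transformers.models.auto import modeling_auto  # noqa: F401 — ensure some submodules are cached
from transformers.utils.import_utils import clear_import_cache

before = [name for name in sys.modules if name.startswith("transformers.")]
clear_import_cache()
after = [name for name in sys.modules if name.startswith("transformers.")]

print(len(before), len(after))  # the count should drop substantially
```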

This workflow is especially valuable when implementing new models, where you need to frequently compare outputs between the original implementation and your Transformers version (as described in the [Add New Model](https://huggingface.co/docs/transformers/add_new_model) guide).
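For instance, a minimal sketch of that comparison loop might look like the following (it assumes PyTorch is installed and uses `bert-base-uncased` purely for illustration):

```python
import torch

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("A fixed probe sentence.", return_tensors="pt")

# Capture reference outputs before editing the modeling code
model = AutoModel.from_pretrained("bert-base-uncased")
with torch.no_grad():
    reference = model(**inputs).last_hidden_state

# ... edit the modeling code, then clear the cache ...
from transformers.utils.import_utils import clear_import_cache

clear_import_cache()

from transformers import AutoModel  # re-import to pick up the edited code

model = AutoModel.from_pretrained("bert-base-uncased")
with torch.no_grad():
    updated = model(**inputs).last_hidden_state

# For a pure refactor, the outputs should match within tolerance
print(torch.allclose(reference, updated, atol=1e-5))
```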

## Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)

25 changes: 25 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -2276,3 +2276,28 @@ def define_import_structure(module_path: str) -> IMPORT_STRUCTURE_T:
"""
import_structure = create_import_structure_from_path(module_path)
return spread_import_structure(import_structure)


# Note: `sys`, `importlib`, and `_LazyModule` are available at module scope in import_utils.py.
def clear_import_cache():
    """
    Clear cached Transformers modules to allow reloading modified code.

    This is useful when actively developing/modifying Transformers code.
    """
    # Get all transformers modules currently loaded
    transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")]

    # Remove them from sys.modules
    for mod_name in transformers_modules:
        module = sys.modules[mod_name]
        # Clear _LazyModule caches if applicable
        if isinstance(module, _LazyModule):
            module._objects = {}  # Clear cached objects
        del sys.modules[mod_name]

    # Force reload main transformers module
    if "transformers" in sys.modules:
        main_module = sys.modules["transformers"]
        if isinstance(main_module, _LazyModule):
            main_module._objects = {}  # Clear cached objects
        importlib.reload(main_module)
23 changes: 23 additions & 0 deletions tests/utils/test_import_utils.py
@@ -0,0 +1,23 @@
import sys

from transformers.utils.import_utils import clear_import_cache


def test_clear_import_cache():
    # Import some transformers modules so the cache has entries to clear
    from transformers.models.auto import modeling_auto  # noqa: F401

    # Get initial module count
    initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}

    # Verify we have some modules loaded
    assert len(initial_modules) > 0

    # Clear cache
    clear_import_cache()

    # Check modules were removed
    remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
    assert len(remaining_modules) < len(initial_modules)

    # Verify we can reimport
    from transformers.models.auto import modeling_auto  # noqa: F401

    assert "transformers.models.auto.modeling_auto" in sys.modules
