Skip to content

Commit

Permalink
Fix quant unpatch (#1129)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced unpatching functionality for quantized models, enhancing
model management.
- Added quantization tracking to the OneflowDeployableModule, improving
configuration capabilities.
	- Implemented a method to check the quantization state of the module.

- **Bug Fixes**
- Improved handling of unpatching conditions for quantized models to
prevent unintended skips.

- **Documentation**
- Updated documentation for the apply_online_quant method with usage
examples.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
strint and coderabbitai[bot] authored Oct 31, 2024
1 parent 177dfc0 commit 445a61e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
20 changes: 20 additions & 0 deletions onediff_comfy_nodes/modules/oneflow/hijack_model_patcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from comfy.model_patcher import ModelPatcher
from onediff.utils import logger

from ..sd_hijack_utils import Hijacker
from .patch_management import create_patch_executor, PatchType
Expand All @@ -18,6 +19,25 @@ def cond_func(org_fn, self):
return is_using_oneflow_backend(self)


def unpatch_model_oneflow(org_fn, self, device_to=None, unpatch_weights=True):
if unpatch_weights:
logger.warning(
f"{type(self.model.diffusion_model)} is quantized by onediff, so unpatching is skipped."
)
return


def unpatch_model_cond_func(org_fn, self, *args, **kwargs):
if hasattr(self.model, "diffusion_model") and hasattr(
self.model.diffusion_model, "_deployable_module_quantized"
):
return self.model.diffusion_model._deployable_module_quantized
return False


model_patch_hijacker = Hijacker()

model_patch_hijacker.register(ModelPatcher.clone, clone_oneflow, cond_func)
model_patch_hijacker.register(
ModelPatcher.unpatch_model, unpatch_model_oneflow, unpatch_model_cond_func
)
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def apply_online_quant(self, quant_config):
>>> model.apply_online_quant(quant_config)
"""
self._deployable_module_quant_config = quant_config
self._deployable_module_quantized = True


def get_mixed_deployable_module(module_cls):
Expand Down

0 comments on commit 445a61e

Please sign in to comment.