diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index cc0abac6e4ab..9375f760a151 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -359,6 +359,8 @@ jobs: test_location: "bnb" - backend: "gguf" test_location: "gguf" + - backend: "torchao" + test_location: "torchao" runs-on: group: aws-g6e-xlarge-plus container: diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md index 1f9f99a79a3b..c056876c2f09 100644 --- a/docs/source/en/quantization/torchao.md +++ b/docs/source/en/quantization/torchao.md @@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] The example below only quantizes the weights to int8. ```python +import torch from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig model_id = "black-forest-labs/FLUX.1-dev" @@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained( ) pipe.to("cuda") +# Without quantization: ~31.447 GB +# With quantization: ~20.40 GB +print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB") + prompt = "A cat holding a sign that says hello world" image = pipe( prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 @@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available. +## Serializing and Deserializing quantized models + +To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method. + +```python +import torch +from diffusers import FluxTransformer2DModel, TorchAoConfig + +quantization_config = TorchAoConfig("int8wo") +transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/Flux.1-Dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, +) +transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False) +``` + +To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method. + +```python +import torch +from diffusers import FluxPipeline, FluxTransformer2DModel + +transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False) +pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "A cat holding a sign that says hello world" +image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0] +image.save("output.png") +``` + +Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source. 
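+
+For illustration, trying to load such a checkpoint directly may fail along the lines of the sketch below (illustrative only; the checkpoint path is a placeholder and the exact exception can vary):
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel
+
+try:
+    # Direct loading of a `uint4wo` checkpoint may fail while unpickling the torchao tensor subclasses
+    transformer = FluxTransformer2DModel.from_pretrained(
+        "/path/to/flux_uint4wo", torch_dtype=torch.bfloat16, use_safetensors=False
+    )
+except Exception as e:
+    # Typically surfaces as an `UnpicklingError`; use the manual loading shown below instead
+    print(f"Direct loading failed: {e}")
+```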
+
+```python
+import torch
+from accelerate import init_empty_weights
+from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+
+# Serialize the model
+transformer = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/Flux.1-Dev",
+    subfolder="transformer",
+    quantization_config=TorchAoConfig("uint4wo"),
+    torch_dtype=torch.bfloat16,
+)
+transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
+# ...
+
+# Load the model
+state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
+with init_empty_weights():
+    transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
+transformer.load_state_dict(state_dict, strict=True, assign=True)
+```
+
 ## Resources
 
 - [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index d236ebb83983..d6efcc736487 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -718,10 +718,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer = None
 
         if hf_quantizer is not None:
-            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
-            if is_bnb_quantization_method and device_map is not None:
+            if device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, providing `device_map` is not supported for quantized models. Support for providing `device_map` as an input will be added in the future."
) hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) @@ -820,7 +819,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder or "", ) - if hf_quantizer is not None and is_bnb_quantization_method: + # TODO: https://github.com/huggingface/diffusers/issues/10013 + if hf_quantizer is not None: model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") is_sharded = False diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py index 5770e32c909e..a829234afd56 100644 --- a/src/diffusers/quantizers/torchao/torchao_quantizer.py +++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py @@ -132,7 +132,7 @@ def validate_environment(self, *args, **kwargs): def update_torch_dtype(self, torch_dtype): quant_type = self.quantization_config.quant_type - if quant_type.startswith("int"): + if quant_type.startswith("int") or quant_type.startswith("uint"): if torch_dtype is not None and torch_dtype != torch.bfloat16: logger.warning( f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but " diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 0fa9182a3314..3c3f13db9b1c 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -131,8 +131,9 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def get_dummy_components(self, quantization_config: TorchAoConfig): - model_id = "hf-internal-testing/tiny-flux-pipe" + def get_dummy_components( + self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe" + ): transformer = FluxTransformer2DModel.from_pretrained( model_id, subfolder="transformer", @@ -211,8 +212,8 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0): "timestep": timestep, } - def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]): - components = self.get_dummy_components(quantization_config) + def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str): + components = self.get_dummy_components(quantization_config, model_id) pipe = FluxPipeline(**components) pipe.to(device=torch_device) @@ -223,44 +224,45 @@ def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: L self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_quantization(self): - # fmt: off - QUANTIZATION_TYPES_TO_TEST = [ - ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), - ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), - ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), - ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), - ] - - if TorchAoConfig._is_cuda_capability_atleast_8_9(): - QUANTIZATION_TYPES_TO_TEST.extend([ - ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), - ("float8wo_e4m3", 
np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), - # ===== - # The following lead to an internal torch error: - # RuntimeError: mat2 shape (32x4 must be divisible by 16 - # Skip these for now; TODO(aryan): investigate later - # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ===== - # Cutlass fails to initialize for below - # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), - # ===== - ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), - ]) - # fmt: on - - for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: - quant_kwargs = {} - if quantization_name in ["uint4wo", "uint7wo"]: - # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here - quant_kwargs.update({"group_size": 16}) - quantization_config = TorchAoConfig( - quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs - ) - self._test_quant_type(quantization_config, expected_slice) + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + # fmt: off + QUANTIZATION_TYPES_TO_TEST = [ + ("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), + ("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), + ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), + ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), + ] + + if TorchAoConfig._is_cuda_capability_atleast_8_9(): + QUANTIZATION_TYPES_TO_TEST.extend([ + ("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), + ("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), + # ===== + # The following lead to an internal torch error: + # RuntimeError: mat2 shape (32x4 must be divisible by 16 + # Skip these for now; TODO(aryan): investigate later + # ("float8dq_e4m3", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ("float8dq_e4m3_tensor", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + # Cutlass fails to initialize for below + # ("float8dq_e4m3_row", np.array([0, 0, 0, 0, 0, 0, 0, 0, 0])), + # ===== + ("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), + ]) + # fmt: on + + for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: + quant_kwargs = {} + if quantization_name in ["uint4wo", "uint7wo"]: + # The dummy flux model that we use has smaller dimensions. 
This imposes some restrictions on group_size here + quant_kwargs.update({"group_size": 16}) + quantization_config = TorchAoConfig( + quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs + ) + self._test_quant_type(quantization_config, expected_slice, model_id) def test_int4wo_quant_bfloat16_conversion(self): """ @@ -280,12 +282,14 @@ def test_int4wo_quant_bfloat16_conversion(self): self.assertEqual(weight.quant_max, 15) def test_device_map(self): + # Note: We were not checking if the weight tensor's were AffineQuantizedTensor's before. If we did + # it would have errored out. Now, we do. So, device_map basically never worked with or without + # sharded checkpoints. This will need to be supported in the future (TODO(aryan)) """ Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps. The custom device map performs cpu/disk offloading as well. Also verifies that the device map is correctly set (in the `hf_device_map` attribute of the model). """ - custom_device_map_dict = { "time_text_embed": torch_device, "context_embedder": torch_device, @@ -297,48 +301,54 @@ def test_device_map(self): } device_maps = ["auto", custom_device_map_dict] - inputs = self.get_dummy_tensor_inputs(torch_device) - expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) + # inputs = self.get_dummy_tensor_inputs(torch_device) + # expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) for device_map in device_maps: - device_map_to_compare = {"": 0} if device_map == "auto" else device_map - - # Test non-sharded model - with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) - quantized_model = FluxTransformer2DModel.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", - subfolder="transformer", - quantization_config=quantization_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - offload_folder=offload_folder, - ) - - self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) - - output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) - - # Test sharded model - with tempfile.TemporaryDirectory() as offload_folder: - quantization_config = TorchAoConfig("int4_weight_only", group_size=64) - quantized_model = FluxTransformer2DModel.from_pretrained( - "hf-internal-testing/tiny-flux-sharded", - subfolder="transformer", - quantization_config=quantization_config, - device_map=device_map, - torch_dtype=torch.bfloat16, - offload_folder=offload_folder, - ) - - self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) - - output = quantized_model(**inputs)[0] - output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + # device_map_to_compare = {"": 0} if device_map == "auto" else device_map + + # Test non-sharded model - should work + with self.assertRaises(NotImplementedError): + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + _ = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + 
offload_folder=offload_folder, + ) + + # weight = quantized_model.transformer_blocks[0].ff.net[2].weight + # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + # self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + # output = quantized_model(**inputs)[0] + # output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) + + # Test sharded model - should not work + with self.assertRaises(NotImplementedError): + with tempfile.TemporaryDirectory() as offload_folder: + quantization_config = TorchAoConfig("int4_weight_only", group_size=64) + _ = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/tiny-flux-sharded", + subfolder="transformer", + quantization_config=quantization_config, + device_map=device_map, + torch_dtype=torch.bfloat16, + offload_folder=offload_folder, + ) + + # weight = quantized_model.transformer_blocks[0].ff.net[2].weight + # self.assertTrue(quantized_model.hf_device_map == device_map_to_compare) + # self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + # output = quantized_model(**inputs)[0] + # output_slice = output.flatten()[-9:].detach().float().cpu().numpy() + + # self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_modules_to_not_convert(self): quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) @@ -404,43 +414,63 @@ def test_training(self): @nightly def test_torch_compile(self): r"""Test that verifies if torch.compile works with torchao quantization.""" - quantization_config = TorchAoConfig("int8_weight_only") - components = self.get_dummy_components(quantization_config) - pipe = FluxPipeline(**components) - pipe.to(device=torch_device, dtype=torch.bfloat16) + for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: + quantization_config = TorchAoConfig("int8_weight_only") + components = self.get_dummy_components(quantization_config, model_id=model_id) + pipe = FluxPipeline(**components) + pipe.to(device=torch_device) - inputs = self.get_dummy_inputs(torch_device) - normal_output = pipe(**inputs)[0].flatten()[-32:] + inputs = self.get_dummy_inputs(torch_device) + normal_output = pipe(**inputs)[0].flatten()[-32:] - pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) - inputs = self.get_dummy_inputs(torch_device) - compile_output = pipe(**inputs)[0].flatten()[-32:] + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) + inputs = self.get_dummy_inputs(torch_device) + compile_output = pipe(**inputs)[0].flatten()[-32:] - # Note: Seems to require higher tolerance - self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) + # Note: Seems to require higher tolerance + self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) def test_memory_footprint(self): r""" A simple test to check if the model conversion has been done correctly by checking on the memory footprint of the converted model and the class type of the linear layers of the converted models """ - transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"))["transformer"] - transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32))["transformer"] - transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"] - transformer_bf16 = 
self.get_dummy_components(None)["transformer"]
-
-        total_int4wo = get_model_size_in_bytes(transformer_int4wo)
-        total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
-        total_int8wo = get_model_size_in_bytes(transformer_int8wo)
-        total_bf16 = get_model_size_in_bytes(transformer_bf16)
-
-        # Latter has smaller group size, so more groups -> more scales and zero points
-        self.assertTrue(total_int4wo < total_int4wo_gs32)
-        # int8 quantizes more layers compare to int4 with default group size
-        self.assertTrue(total_int8wo < total_int4wo)
-        # int4wo does not quantize too many layers because of default group size, but for the layers it does
-        # there is additional overhead of scales and zero points
-        self.assertTrue(total_bf16 < total_int4wo)
+        for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
+            transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
+            transformer_int4wo_gs32 = self.get_dummy_components(
+                TorchAoConfig("int4wo", group_size=32), model_id=model_id
+            )["transformer"]
+            transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
+            transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
+
+            # Will not quantize all the layers by default due to the model weight shapes not being divisible by group_size=64
+            for block in transformer_int4wo.transformer_blocks:
+                self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor))
+                self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor))
+
+            # Will quantize all the linear layers except x_embedder
+            for name, module in transformer_int4wo_gs32.named_modules():
+                if isinstance(module, nn.Linear) and name not in ["x_embedder"]:
+                    self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+            # Will quantize all the linear layers
+            for module in transformer_int8wo.modules():
+                if isinstance(module, nn.Linear):
+                    self.assertTrue(isinstance(module.weight, AffineQuantizedTensor))
+
+            total_int4wo = get_model_size_in_bytes(transformer_int4wo)
+            total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
+            total_int8wo = get_model_size_in_bytes(transformer_int8wo)
+            total_bf16 = get_model_size_in_bytes(transformer_bf16)
+
+            # TODO: refactor to align with other quantization tests
+            # Latter has smaller group size, so more groups -> more scales and zero points
+            self.assertTrue(total_int4wo < total_int4wo_gs32)
+            # int8 quantizes more layers compared to int4 with default group size
+            self.assertTrue(total_int8wo < total_int4wo)
+            # int4wo does not quantize too many layers because of default group size, but for the layers it does
+            # there is additional overhead of scales and zero points
+            self.assertTrue(total_bf16 < total_int4wo)
 
     def test_wrong_config(self):
         with self.assertRaises(ValueError):
@@ -500,6 +530,8 @@ def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs,
         inputs = self.get_dummy_tensor_inputs(torch_device)
         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
+        weight = quantized_model.transformer_blocks[0].ff.net[2].weight
+        self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)))
         self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
 
     def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device):
@@ -508,8 +540,8 @@ def
_check_serialization_expected_slice(self, quant_method, quant_method_kwargs, with tempfile.TemporaryDirectory() as tmp_dir: quantized_model.save_pretrained(tmp_dir, safe_serialization=False) loaded_quantized_model = FluxTransformer2DModel.from_pretrained( - tmp_dir, torch_dtype=torch.bfloat16, device_map=torch_device, use_safetensors=False - ) + tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False + ).to(device=torch_device) inputs = self.get_dummy_tensor_inputs(torch_device) output = loaded_quantized_model(**inputs)[0] @@ -563,20 +595,25 @@ def tearDown(self): torch.cuda.empty_cache() def get_dummy_components(self, quantization_config: TorchAoConfig): + # This is just for convenience, so that we can modify it at one place for custom environments and locally testing + cache_dir = None model_id = "black-forest-labs/FLUX.1-dev" transformer = FluxTransformer2DModel.from_pretrained( model_id, subfolder="transformer", quantization_config=quantization_config, torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + text_encoder = CLIPTextModel.from_pretrained( + model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir ) - text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) text_encoder_2 = T5EncoderModel.from_pretrained( - model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16 + model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir ) - tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") - tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") - vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) + tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir) + tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir) + vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir) scheduler = FlowMatchEulerDiscreteScheduler() return { @@ -611,10 +648,12 @@ def _test_quant_type(self, quantization_config, expected_slice): pipe = FluxPipeline(**components) pipe.enable_model_cpu_offload() + weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight + self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) + inputs = self.get_dummy_inputs(torch_device) output = pipe(**inputs)[0].flatten() output_slice = np.concatenate((output[:16], output[-16:])) - self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) def test_quantization(self): @@ -627,7 +666,7 @@ def test_quantization(self): if TorchAoConfig._is_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES_TO_TEST.extend([ ("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), - ("fp5_e3m1", np.array([0.0527, 0.0742, 0.1289, 0.0449, 0.0625, 0.1308, 0.0585, 0.0742, 0.1269, 0.0585, 0.0722, 0.1328, 0.0566, 0.0742, 0.1347, 0.0585, 0.3691, 0.7578, 0.5429, 0.4355, 0.7695, 0.5546, 0.4414, 0.7578, 0.5468, 0.4179, 0.7265, 0.5273, 0.3945, 0.6992, 0.5234, 0.4316])), + ("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 
0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])), ]) # fmt: on @@ -637,3 +676,125 @@ def test_quantization(self): gc.collect() torch.cuda.empty_cache() torch.cuda.synchronize() + + def test_serialization_int8wo(self): + quantization_config = TorchAoConfig("int8wo") + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.enable_model_cpu_offload() + + weight = pipe.transformer.x_embedder.weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten()[:128] + + with tempfile.TemporaryDirectory() as tmp_dir: + pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False) + pipe.remove_all_hooks() + del pipe.transformer + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + transformer = FluxTransformer2DModel.from_pretrained( + tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False + ) + pipe.transformer = transformer + pipe.enable_model_cpu_offload() + + weight = transformer.x_embedder.weight + self.assertTrue(isinstance(weight, AffineQuantizedTensor)) + + loaded_output = pipe(**inputs)[0].flatten()[:128] + # Seems to require higher tolerance depending on which machine it is being run. + # A difference of 0.06 in normalized pixel space (-1 to 1), corresponds to a difference of + # 0.06 / 2 * 255 = 7.65 in pixel space (0 to 255). On our CI runners, the difference is about 0.04, + # on DGX it is 0.06, and on audace it is 0.037. So, we are using a tolerance of 0.06 here. + self.assertTrue(np.allclose(output, loaded_output, atol=0.06)) + + def test_memory_footprint_int4wo(self): + # The original checkpoints are in bf16 and about 24 GB + expected_memory_in_gb = 6.0 + quantization_config = TorchAoConfig("int4wo") + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3 + self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb) + + def test_memory_footprint_int8wo(self): + # The original checkpoints are in bf16 and about 24 GB + expected_memory_in_gb = 12.0 + quantization_config = TorchAoConfig("int8wo") + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, + cache_dir=cache_dir, + ) + int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3 + self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb) + + +@require_torch +@require_torch_gpu +@require_torchao_version_greater_or_equal("0.7.0") +@slow +@nightly +class SlowTorchAoPreserializedModelTests(unittest.TestCase): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self, device: torch.device, seed: int = 0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator().manual_seed(seed) + + inputs = { + "prompt": "an astronaut riding a horse in space", + "height": 512, + "width": 512, + "num_inference_steps": 20, + "output_type": "np", + "generator": generator, + } + + return inputs + + def test_transformer_int8wo(self): + # fmt: off + expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 
0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703]) + # fmt: on + + # This is just for convenience, so that we can modify it at one place for custom environments and locally testing + cache_dir = None + transformer = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer", + torch_dtype=torch.bfloat16, + use_safetensors=False, + cache_dir=cache_dir, + ) + pipe = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir + ) + pipe.enable_model_cpu_offload() + + # Verify that all linear layer weights are quantized + for name, module in pipe.transformer.named_modules(): + if isinstance(module, nn.Linear): + self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) + + # Verify outputs match expected slice + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0].flatten() + output_slice = np.concatenate((output[:16], output[-16:])) + self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))