@@ -10,9 +10,6 @@
 from typing import Any, Optional, Tuple
 
-import float8_experimental.config as config
-
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree
 
 from float8_experimental.float8_tensor import (
@@ -85,7 +82,12 @@ def cast_to_float8_e5m2_dynamic_bw(
 
 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
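For context, `__new__` here follows PyTorch's wrapper-subclass pattern: `torch.Tensor._make_wrapper_subclass` allocates no storage of its own, so the subclass carries only metadata (shape, strides, dtype, device) while the wrapped high-precision tensor holds the actual data. A minimal standalone sketch of that pattern, with illustrative names that are not from this repo:

    import torch

    class WrapperTensor(torch.Tensor):
        @staticmethod
        def __new__(cls, inner: torch.Tensor):
            # mirror the inner tensor's metadata; no storage is allocated here
            return torch.Tensor._make_wrapper_subclass(
                cls,
                inner.size(),
                strides=inner.stride(),
                dtype=inner.dtype,
                device=inner.device,
                requires_grad=inner.requires_grad,
            )

        def __init__(self, inner: torch.Tensor):
            self._inner = inner  # the actual data lives here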
@@ -99,9 +101,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )
 
-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        # for dynamic scaling
+        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+        # for all float8 parameters after optimizer step
+        self._precomputed_scale = precomputed_scale
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
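The new comment points at the intended workflow: instead of recomputing each weight's scale (with its amax all-reduce) inside every `fsdp_pre_all_gather`, the scales for all float8 parameters are computed once per iteration, right after the optimizer update, and stashed in `_precomputed_scale`. A hedged usage sketch, assuming the helper is exported from `float8_experimental.fsdp_utils` as the comment suggests, and that `model`, `optimizer`, and `dataloader` are already set up with FSDP and float8 linears:

    from float8_experimental.fsdp_utils import (
        precompute_float8_dynamic_scale_for_fsdp,
    )

    for batch in dataloader:
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # one batched amax/scale computation for all float8 weights, so the
        # next all-gather can reuse each weight's _precomputed_scale
        precompute_float8_dynamic_scale_for_fsdp(model)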
@@ -130,20 +141,35 @@ def unwrap(t):
         )
 
     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
+        else:
+            return ["_tensor"], self._mm_config
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"],
+            mm_config,
+            inner_tensors.get("_precomputed_scale", None),
+        )
 
     def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
 
     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._precomputed_scale is not None:
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3_dynamic(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
         return (float8_tensor._data,), (float8_tensor._scale,)
 
     def fsdp_post_all_gather(
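The branch above makes the benefit concrete: when `_precomputed_scale` is present, the pre-all-gather cast is a single `Float8Tensor.to_float8` call with no communication, while the fallback path recomputes the scale from the live weight, including the amax all-reduce implied by `reduce_amax=True`. For intuition, here is a rough sketch of the per-tensor dynamic scale that gets precomputed, based on the usual float8 recipe; the repo's own helper may differ in details such as the clamping epsilon:

    import torch

    def dynamic_scale_sketch(w: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
        # map the weight's max magnitude onto the representable range of
        # float8_e4m3fn (torch.finfo(torch.float8_e4m3fn).max == 448.0)
        amax = w.abs().max().to(torch.float64)
        scale = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=eps)
        return scale.to(torch.float32)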