Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizer state scaling #44

Merged
merged 16 commits into from
Aug 22, 2020
10 changes: 5 additions & 5 deletions benchmarks/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ def make_model(device, ntokens):
p = Pipe(model, balance)

criterion = nn.CrossEntropyLoss()
lr = 0.01 # learning rate
lr = 0.0005 # learning rate

try:
optimizer = Adam(p.parameters(), lr=lr, precision=Precision.MIXED_PRECISION)
optimizer = Adam(p.parameters(), lr=lr, precision=Precision.PURE_FP16)
except NameError:
optimizer = Adam(p.parameters(), lr=lr)

Expand Down Expand Up @@ -236,10 +236,10 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,

# Assert that memory usage on each GPU is within 10% of golden run
# Right-hand-side is golden run bytes * 110%
assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 210479616 * 1.1
assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 193206272 * 1.1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how did you come up with these number of bytes?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had the same question in one of the previous PRs :D
I guess Jun-Ru printed out the value of torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] and put that number in the check!

assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 640512 * 1.1
assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1605120 * 1.1
assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 113801216 * 1.1
assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1412608 * 1.1
assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 95364608 * 1.1
print("No regression detected")


Expand Down
2 changes: 1 addition & 1 deletion fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <torch/extension.h>

// CUDA forward declaration
void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, float optim_scale, at::Tensor& found_inf, int step, int mode, int bias_correction, float decay);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &fused_adam_cuda, "Multi tensor Adam optimized CUDA implementation.");
Expand Down
74 changes: 59 additions & 15 deletions fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <assert.h>
#include <cmath>
#include "ATen/TensorUtils.h"
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"
Expand All @@ -31,6 +30,9 @@ struct AdamFunctor
const float b2,
const float eps,
const float grad_scale,
const bool use_optim_scaling,
const float optim_scale,
andersonic marked this conversation as resolved.
Show resolved Hide resolved
float* found_inf_ptr,
const float step_size,
adamMode_t mode,
const float decay)
Expand Down Expand Up @@ -90,19 +92,43 @@ struct AdamFunctor
int j = i_start + threadIdx.x + ii*blockDim.x;

if(j < n && j < chunk_size) {
float scaled_grad = incoming_g[ii]/grad_scale;
float momentum = b1 * incoming_m[ii] + (1-b1)*scaled_grad;
float velocity = b2 * incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
m[j] = static_cast<OPTIM_T>(momentum);
v[j] = static_cast<OPTIM_T>(velocity);
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(velocity + eps);
else // Mode 1
denom = sqrtf(velocity) + eps;
float update = (momentum/denom) + (decay*incoming_p[ii]);
p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
if (use_optim_scaling) {
// Optimizer state is in half precision and must be scaled
float scaled_grad = incoming_g[ii]/grad_scale;
float momentum = b1 * (incoming_m[ii] / optim_scale) + (1-b1)*scaled_grad;
float velocity = b2 * (incoming_v[ii] / optim_scale) + (1-b2)*scaled_grad*scaled_grad;

m[j] = static_cast<OPTIM_T>(momentum * optim_scale);
v[j] = static_cast<OPTIM_T>(velocity * optim_scale);

if (!isfinite(m[j]) || !isfinite(v[j])) {
*found_inf_ptr = 1.f;
}

float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(velocity + eps);
else // Mode 1
denom = sqrtf(velocity) + eps;
float update = (momentum/denom) + (decay*incoming_p[ii]);
p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
} else {
// Optimizer state is in floating point precision
float scaled_grad = incoming_g[ii]/grad_scale;
float momentum = b1 * incoming_m[ii] + (1-b1)*scaled_grad;
float velocity = b2 * incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
m[j] = static_cast<OPTIM_T>(momentum);
v[j] = static_cast<OPTIM_T>(velocity);
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(velocity + eps);
else // Mode 1
denom = sqrtf(velocity) + eps;
float update = (momentum/denom) + (decay*incoming_p[ii]);
p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
}
}
}
}
Expand All @@ -118,6 +144,8 @@ void fused_adam_cuda(
float beta2,
float eps,
float grad_scale,
float optim_scale,
at::Tensor& found_inf,
int step,
int mode,
int bias_correction,
Expand All @@ -139,6 +167,9 @@ void fused_adam_cuda(
assert(tl_sz == 4 || tl_sz == 5);
assert(tensor_lists[1][0].scalar_type() == tensor_lists[2][0].scalar_type());

bool use_optim_scaling = (tensor_lists[1][0].scalar_type() == at::ScalarType::Half);
float* found_inf_ptr = found_inf.data_ptr<float>();

if(tl_sz == 5) {
// Mixed precision case
assert(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
Expand All @@ -154,6 +185,9 @@ void fused_adam_cuda(
beta2,
eps,
grad_scale,
use_optim_scaling,
optim_scale,
found_inf_ptr,
step_size,
(adamMode_t) mode,
decay
Expand All @@ -174,13 +208,17 @@ void fused_adam_cuda(
beta2,
eps,
grad_scale,
use_optim_scaling,
optim_scale,
found_inf_ptr,
step_size,
(adamMode_t) mode,
decay
);
} else if (tensor_lists[0][0].scalar_type() == at::ScalarType::Half) {
if(tensor_lists[1][0].scalar_type() == at::ScalarType::Float) {
// FP16 model parameters and gradients; FP32 optimizer state
// Memory-efficient mixed-precision case
// ie FP16 model parameters and gradients; FP32 optimizer state
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
Expand All @@ -191,6 +229,9 @@ void fused_adam_cuda(
beta2,
eps,
grad_scale,
use_optim_scaling,
optim_scale,
found_inf_ptr,
step_size,
(adamMode_t) mode,
decay
Expand All @@ -207,6 +248,9 @@ void fused_adam_cuda(
beta2,
eps,
grad_scale,
use_optim_scaling,
optim_scale,
found_inf_ptr,
step_size,
(adamMode_t) mode,
decay
Expand Down
51 changes: 50 additions & 1 deletion fairscale/optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ class Precision(Enum):
MEMORY_EFFICIENT_MIXED_PRECISION = auto()
PURE_FP16 = auto()

class _MultiDeviceReplicator(object):
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""

def __init__(self, master_tensor: torch.Tensor):
assert master_tensor.is_cuda
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

def get(self, device: torch.device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval

class Adam(torch.optim.Optimizer):
state: dict
defaults: dict
Expand Down Expand Up @@ -81,7 +98,9 @@ def __init__(
assert parameters[0].dtype == torch.float16

self.optim_type = torch.float16 if precision is Precision.PURE_FP16 else torch.float32

self._optim_scale = float(2 ** 16) if precision is Precision.PURE_FP16 else 1.0
self._steps_since_optim_scale_change = 0
self._optim_scale_update_freq = 2000 # This is the value that GradScaler uses by default
andersonic marked this conversation as resolved.
Show resolved Hide resolved
self._overflow_buf = torch.cuda.IntTensor([0]) # type: ignore

if amsgrad:
Expand Down Expand Up @@ -145,8 +164,14 @@ def _step_supports_amp_scaling(self) -> bool:
def mixed_precision(self) -> bool:
return self.precision is Precision.MIXED_PRECISION

def state_dict(self) -> Dict[str, Any]:
d = super().state_dict()
d["optim_scale"] = self._optim_scale
return d

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
super().load_state_dict(state_dict)
self._optim_scale = state_dict["optim_scale"]

# TODO: Optimizer state gets cast to FP16 and back to FP32 for
# mixed-precision and memory-efficient mixed-precision. Eventually
Expand Down Expand Up @@ -228,6 +253,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
for tl, t in zip(tensorlists[p.device], pl):
tl.append(t)

found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=list(tensorlists.keys())[0])
per_device_found_inf = _MultiDeviceReplicator(found_inf)
andersonic marked this conversation as resolved.
Show resolved Hide resolved

for tensordevice, tensorlist in tensorlists.items():
with torch.cuda.device(tensordevice):
fused_adam_cuda.adam(
Expand All @@ -239,12 +267,33 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
beta2,
group["eps"],
scale,
self._optim_scale,
per_device_found_inf.get(tensordevice),
state["step"],
self.eps_mode,
bias_correction,
group["weight_decay"],
)

if sum(v.item() for v in per_device_found_inf._per_device_tensors.values()):
self._steps_since_optim_scale_change = 0
self._optim_scale /= 2

if self._optim_scale < 1.0:
raise RuntimeError("Optimizer state scale < 1. This may mean that gradients are exploding")

for group in self.param_groups:
for p in group["params"]:
self.state[p]["exp_avg"] = torch.zeros_like(p, dtype=self.optim_type)
self.state[p]["exp_avg_sq"] = torch.zeros_like(p, dtype=self.optim_type)
else:
self._steps_since_optim_scale_change += 1

if self._steps_since_optim_scale_change == self._optim_scale_update_freq:
self._steps_since_optim_scale_change = 0
if self._optim_scale < 2 ** 16:
self._optim_scale *= 2

return loss


Expand Down
32 changes: 32 additions & 0 deletions tests/optim/test_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,38 @@ def test_state_dict_pure_fp16():
state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_update_optim_scale():
weight, bias, input = make_half_precision_params()
optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
optimizer._optim_scale_update_freq = 1
optimizer._optim_scale = 2 ** 15

optimizer.zero_grad()
loss = (weight.mv(input) + bias).pow(2).sum()
loss.backward()
optimizer.step()

assert optimizer._optim_scale == 2 ** 16


@skip_if_no_cuda
@skip_if_no_adam
def test_exploding_optimizer_state():
weight = torch.tensor([[float("inf")]]).half().cuda().requires_grad_()
input = torch.tensor([1.0]).half().cuda().requires_grad_()

optimizer = Adam([weight], lr=1e-3, precision=Precision.PURE_FP16)
optimizer._optim_scale = 1.0

optimizer.zero_grad()
loss = (weight.mv(input)).pow(2).sum()
loss.backward()
with pytest.raises(RuntimeError):
optimizer.step()


@skip_if_no_cuda
@skip_if_no_adam
def test_build_fp32_params():
Expand Down