
[3/x]: simplify FSDP1 test and add coverage for dynamic scaling
Summary:

1. simplify the FSDP test: instead of testing 1 GPU vs N GPUs, hold
   the number of GPUs constant and test bf16 vs float8. Remove
   various technical debt that accumulated in this test.
2. add testing for dynamic scaling of weights (see the sketch below)
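As illustration, here is a minimal single-process sketch of the new comparison (not the test itself). It reuses the `float8_experimental` APIs and constants visible in the diff below (`swap_linear_with_float8_linear`, `sync_float8_amax_and_scale_history`, `compute_error`, the 2-iteration loop, the ~15 dB SQNR bar); the toy setup without FSDP or an optimizer is a simplification for readability.

```python
# Sketch only: bf16 vs float8 comparison without FSDP or an optimizer.
# Assumes a CUDA device with float8 support (otherwise pass emulate=True).
import copy

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error


def compare_bf16_vs_float8(use_weight_dynamic_scaling: bool) -> None:
    B, M, K, N = 8, 8, 32, 32
    model = nn.Sequential(
        nn.Linear(K, N, dtype=torch.bfloat16),
        nn.ReLU(),
        nn.Linear(N, N, dtype=torch.bfloat16),
        nn.ReLU(),
    ).cuda()
    model_fp8 = copy.deepcopy(model)

    # FSDP only interacts with weights, so only weight scaling is varied
    scaling_type_w = (
        TensorScalingType.DYNAMIC
        if use_weight_dynamic_scaling
        else TensorScalingType.DELAYED
    )
    swap_linear_with_float8_linear(
        model_fp8, Float8Linear, emulate=False, scaling_type_w=scaling_type_w
    )

    # two iterations so that delayed scaling runs with a populated amax history
    for _ in range(2):
        x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")
        go = torch.randn(B, M, N, dtype=torch.bfloat16, device="cuda")
        y = model(x)
        y.backward(go)
        y_fp8 = model_fp8(x)
        y_fp8.backward(go)
        sync_float8_amax_and_scale_history(model_fp8)
        # bf16 and float8 outputs are expected to stay within ~15 dB SQNR
        sqnr = compute_error(y, y_fp8)
        assert sqnr > 15.0, f"SQNR of {sqnr} is too low"
```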

Test Plan:

```
./test/test_fsdp.sh
```
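The shell script loops over the `(COMPILE, USE_WEIGHT_DYNAMIC_SCALING)` matrix and forwards the flags to `test/test_fsdp.py`. As a hedged sketch, the same four configurations could be driven directly from Python through the `run()` entry point shown in the diff; the script's argument parsing is truncated in the diff, so the direct import below is an assumption.

```python
# Sketch: the four configurations exercised by ./test/test_fsdp.sh,
# assuming test/test_fsdp.py is importable and run() can be called directly.
import itertools

from test_fsdp import run  # hypothetical direct import of the test module

for compile_fsdp, use_weight_dynamic_scaling in itertools.product(
    [False, True], [False, True]
):
    run(
        compile_fsdp=compile_fsdp,
        use_weight_dynamic_scaling=use_weight_dynamic_scaling,
    )
```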

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: fcd6368857a35e2ee227f58d389512682d85e409
Pull Request resolved: #293
vkuzo committed Jul 1, 2024
1 parent a27b2fc commit a0ad964
Showing 2 changed files with 101 additions and 177 deletions.
251 changes: 95 additions & 156 deletions test/test_fsdp.py
@@ -4,15 +4,14 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Test numerics of single GPU vs FSDP of toy model. At a high level:
1. start with reference input and state dict for a single GPU model
2. run fw+bw+optim on single GPU, save the results
3. run fw+bw+optim with FSDP, save the results
4. verify that the outputs and state dicts after optim update match
later 1-4 can be repeated for fp16, various combinations of fp8, etc.
Test numerics of bf16 versus float8 with FSDP on. At a high level:
1. start with a reference model, with FSDP on
2. run forward + backward + optim for 2 iterations
3. repeat 2 with float8 enabled (2 iterations needed for delayed scaling)
4. compare outputs and state dict between (2) and (3), should be close
"""

import copy
import os
import warnings

@@ -22,11 +21,12 @@
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_utils import compute_error
from torch.distributed.fsdp import (
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
@@ -35,18 +35,9 @@

torch.manual_seed(0)

# assumes user is running the script from /data/users/{user}/float8_experimental
data_dir = os.path.join(os.path.dirname(__file__), "tmp")
input_fname = os.path.join(data_dir, "input.pt")
sd_in_fname = os.path.join(data_dir, "sd_in.pt")
sd_out_single_gpu_fname = os.path.join(data_dir, "sd_out_single_gpu.pt")
sd_out_fsdp_fname = os.path.join(data_dir, "sd_out_fsdp.pt")
output_single_gpu_fname = os.path.join(data_dir, "output_single_gpu.pt")
output_fsdp_fname = os.path.join(data_dir, "output_fsdp.pt")

B, M, K, N = 8, 8, 32, 32
lr = 0.01
N_ITER = 5
N_ITER = 2


def setup(rank, world_size):
@@ -61,15 +52,13 @@ def cleanup():
dist.destroy_process_group()


def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
def get_model(K, N, base_dtype=torch.float32):
m = nn.Sequential(
nn.Linear(K, N, dtype=base_dtype),
nn.ReLU(),
nn.Linear(N, N, dtype=base_dtype),
nn.ReLU(),
)
if is_fp8:
swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
return m


@@ -79,52 +68,84 @@ def fsdp_main(rank, world_size, args):
setup(rank, world_size)
torch.cuda.set_device(rank)

# TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile.
# We can investigate and fix it later.
is_fp8, emulate, base_dtype, compile, fullgraph = args
model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to(
rank
emulate, base_dtype, compile, use_weight_dynamic_scaling = args
model = get_model(K, N, base_dtype=base_dtype).to(rank)
model_fp8 = copy.deepcopy(model)

scaling_type_w = (
TensorScalingType.DYNAMIC
if use_weight_dynamic_scaling
else TensorScalingType.DELAYED
)

# Note: we only iterate over `scaling_type_w` because FSDP only interacts
# with weights.
swap_linear_with_float8_linear(
model_fp8,
Float8Linear,
emulate=False,
scaling_type_w=scaling_type_w,
)
model.load_state_dict(torch.load(sd_in_fname, weights_only=True))

# To compile FSDP, we need to set use_orig_params to True
model = FSDP(model, use_orig_params=True)
model_fp8 = FSDP(model_fp8, use_orig_params=True)
# TODO: The following line doesn't work. We should fix it.
# model = FSDP(torch.compile(model), use_orig_params=True)

# Note: we need to multiply by world_size here to match single GPU
# optimizer update
optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optimizer_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)

ref_input_global = torch.load(input_fname, weights_only=True).to(base_dtype)
# Note: we need two different inputs to properly measure the impact of
# delayed scaling, because the first input uses dynamic scaling to
# populate the buffers
ref_input_global = [
torch.randn(B, M, K).cuda().to(base_dtype),
torch.randn(B, M, K).cuda().to(base_dtype),
]
ref_grad_global = [
torch.randn(B, M, N).cuda().to(base_dtype),
torch.randn(B, M, N).cuda().to(base_dtype),
]
ref_input_local = []
ref_grad_local = []

# basic distributed data sampling
assert B % world_size == 0
bsz_local_start = int(rank / world_size * B)
bsz_local_end = int((rank + 1) / world_size * B)
ref_input_local = ref_input_global[bsz_local_start:bsz_local_end].to(rank)
for idx in range(N_ITER):
ref_input_local.append(
ref_input_global[idx][bsz_local_start:bsz_local_end].to(rank)
)
ref_grad_local.append(
ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank)
)

sync_float8_func = sync_float8_amax_and_scale_history
if compile:
sync_float8_func = torch.compile(
sync_float8_amax_and_scale_history, fullgraph=fullgraph
)

def forward_backward(model):
optimizer.zero_grad()
y_local = model(ref_input_local)
y_local.sum().backward()
sync_float8_func(model)
optimizer.step()
sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)

def forward_backward(model, optim, is_fp8, i):
optim.zero_grad()
y_local = model(ref_input_local[i])
y_local.backward(ref_grad_local[i])
if is_fp8:
sync_float8_func(model)
optim.step()
return y_local

for iter in range(N_ITER):
for i in range(N_ITER):
# We first run one iteration without compile, as a workaround to compile float8 layer.
# In the first iter, float8 layers go to the branches of "self.is_amax_initialized == False"
# After that, float8 layers go to the branches of "self.is_amax_initialized == True"
# TODO: Need to fix compile to run without this workaround.
if iter == 1 and compile:
model = torch.compile(model, fullgraph=fullgraph)
y_local = forward_backward(model)
if i == 1 and compile:
model = torch.compile(model)
model_fp8 = torch.compile(model_fp8)
y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
local_sqnr = compute_error(y_local, y_local_fp8)

# get global y
y_global = [
@@ -133,132 +154,50 @@
]
dist.all_gather(y_global, y_local)
y_global = torch.cat(y_global, dim=0)
y_global_fp8 = [
torch.zeros(*y_local_fp8.shape, dtype=base_dtype).to(rank)
for r in range(world_size)
]
dist.all_gather(y_global_fp8, y_local_fp8)
y_global_fp8 = torch.cat(y_global_fp8, dim=0)
if rank == 0:
torch.save(y_global, output_fsdp_fname)
sqnr = compute_error(y_global, y_global_fp8)
assert sqnr > 15.0, f"SQNR of {sqnr} is too low"

# get global state dict
# https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html
dist.barrier()
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
with FSDP.state_dict_type(model_fp8, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state_fp8 = model_fp8.state_dict()
if rank == 0:
torch.save(cpu_state, sd_out_fsdp_fname)
for k, v1 in cpu_state.items():
v2 = cpu_state_fp8[k]
v1, v2 = v1.cpu(), v2.cpu()
sqnr = compute_error(v1, v2)
assert sqnr > 15.0, f"SQNR of {sqnr} is too low, k: {k}, v1: {v1}, v2: {v2}"

cleanup()


def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False):
print(f"Mode: {mode}".center(100, "-"))
def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False):
base_dtype = torch.bfloat16
if not os.path.exists(data_dir):
os.makedirs(data_dir)

emulate = False
if is_fp8:
if not torch.cuda.is_available():
warnings.warn("CUDA not available, running in emulation_mode")
emulate = True
elif torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
)
emulate = True

if mode == "generate":
# generate reference input
ref_input = torch.randn(B, M, K).cuda().to(base_dtype)
model = get_model(
K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
).cuda()
torch.save(ref_input, input_fname)
torch.save(model.state_dict(), sd_in_fname)

elif mode == "single_gpu":
ref_input = torch.load(input_fname, weights_only=True).to(base_dtype)
model = get_model(
K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
).cuda()
model.load_state_dict(torch.load(sd_in_fname, weights_only=True))
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

def forward_backward():
optimizer.zero_grad()
y = model(ref_input)
y.sum().backward()
sync_float8_amax_and_scale_history(model)
optimizer.step()
return y

for _ in range(N_ITER):
y = forward_backward()

torch.save(y, output_single_gpu_fname)
torch.save(model.state_dict(), sd_out_single_gpu_fname)

elif mode == "fsdp":
WORLD_SIZE = torch.cuda.device_count()
# We only compile for fsdp, and compare the numerics with single-gpu no-compile
args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph)
mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

elif mode == "analyze":
y_single_gpu = torch.load(output_single_gpu_fname, weights_only=True).cpu()
y_fsdp = torch.load(output_fsdp_fname, weights_only=True).cpu()
if is_fp8 and not emulate:
atol, rtol = 2e-2, 2e-2
else:
atol, rtol = None, None
torch.testing.assert_close(y_single_gpu, y_fsdp, atol=atol, rtol=rtol)
print("output testing single_gpu vs FSDP success")

sd_out_single_gpu = torch.load(sd_out_single_gpu_fname, weights_only=True)
sd_out_fsdp = torch.load(sd_out_fsdp_fname, weights_only=True)
for k, v1 in sd_out_single_gpu.items():
if compile_fsdp:
# The state-dict for compiled fsdp has a `_orig_mod` prefix
k = f"_orig_mod.{k}"
v2 = sd_out_fsdp[k]
v1, v2 = v1.cpu(), v2.cpu()
if is_fp8 and "noop" in k:
# Note: for fp8 single-node vs FSDP, we are not expected
# to match the scale of the gradients which follow the following
# pattern:
#
# `op(g_prev, out_scale) -> g_fp8 -> cast -> g_fp16 -> reduce`.
#
# Reasoning is the order of operations of calculating the above:
# a. single node:
# 1. calculate dL_dValue and s_dL_dValue
# 2. you're done
# b. FSDP:
# 1. calculate dL_dValue and s_dL_dValue of each slice
# 2. reduce using summation
#
# a and b cannot always match because calculating the scale
# involves taking max(dL_dW), FSDP reduces the gradients, and
# max(abs(a), abs(b)) != max(abs(a + b))
#
# In today's codebase, we do not hit this yet. We expect to hit
# this if we implement TP with activation gradients that both need
# reductions and need fp8 distributed comms. Solution - TBD.

# noop buffers are unused, so ok for them to not match
pass
else:
try:
if v1.dtype == torch.bfloat16 and not emulate:
atol, rtol = 2e-2, 2e-2
else:
if k == "1.fp8_amax_history_x" and not emulate:
atol, rtol = 2e-2, 6e-3
else:
atol, rtol = None, None
torch.testing.assert_close(v1, v2, atol=atol, rtol=rtol)
except Exception as e:
print("debug:", k, v1, v2)
raise e
print("state dict testing single_gpu vs FSDP success")
if not torch.cuda.is_available():
warnings.warn("CUDA not available, running in emulation_mode")
emulate = True
elif torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
)
emulate = True

WORLD_SIZE = torch.cuda.device_count()
args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling)
mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)


if __name__ == "__main__":
27 changes: 6 additions & 21 deletions test/test_fsdp.sh
@@ -4,27 +4,12 @@
set -e

launch() {
echo "launching IS_FP8 $IS_FP8, compile_fsdp $COMPILE, fullgraph $FULLGRAPH"
echo "launching compile_fsdp $COMPILE, use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING"

# generate the test data
python test/test_fsdp.py --mode generate --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
echo "Success: ✅"

# generate single GPU model output and updated state dict
python test/test_fsdp.py --mode single_gpu --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
echo "Success: ✅"

# generate FSDP model output and updated state dict
# the NCCL_DEBUG setting is to avoid log spew
# the CUDA_VISIBLE_DEVICES setting is for easy debugging
# the NCCL_NET setting is to work around transient issues on a
# specific host (`devgpu001.nha2`)
NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 NCCL_NET=SOCKET python test/test_fsdp.py \
--mode fsdp --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH

# compare the outputs and state dicts and verify equivalence
python test/test_fsdp.py --mode analyze --is_fp8 $IS_FP8 --compile_fsdp $COMPILE --fullgraph $FULLGRAPH
echo "Success: ✅"
NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp.py \
--compile_fsdp $COMPILE --use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING

echo "✅ All Tests Passed ✅"
}
@@ -34,10 +19,10 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False";
exit
fi

# IS_FP8, COMPILE, FULLGRAPH
for i in False,False,False True,False,False True,True,False
# COMPILE, USE_WEIGHT_DYNAMIC_SCALING
for i in False,False False,True True,False True,True
do
IFS=","; set -- $i;
IS_FP8=$1; COMPILE=$2; FULLGRAPH=$3
COMPILE=$1; USE_WEIGHT_DYNAMIC_SCALING=$2
launch
done
