[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 23, 2024
1 parent 3cb5af6 commit ee34110
Showing 11 changed files with 327 additions and 324 deletions.
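This commit was generated automatically by pre-commit.ci, which runs the repository's configured pre-commit hooks and commits whatever they rewrite. The hook configuration itself is not part of this diff; as a hypothetical illustration only, a formatter-driven setup of roughly the following shape (the black hook and its rev are assumptions, consistent with the line-length and trailing-comma changes below) would produce fixes like the ones in this commit:

    # .pre-commit-config.yaml (hypothetical sketch, not taken from this repository)
    repos:
      - repo: https://github.com/psf/black
        rev: 24.4.2        # assumed version; pin to whatever the project actually uses
        hooks:
          - id: black

Running "pre-commit run --all-files" locally applies the same hooks, so a branch can be cleaned up before pre-commit.ci pushes an auto-fix commit like this one.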
59 changes: 33 additions & 26 deletions tests/pytorch/distributed/run_gemm_with_overlap.py
@@ -53,9 +53,7 @@ def _parse_args(argv=None, namespace=None):
         "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
     )
     parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
-    parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enable FP8 GEMM."
-    )
+    parser.add_argument("--fp8", action="store_true", default=False, help="Enable FP8 GEMM.")
     parser.add_argument(
         "--p2p", action="store_true", default=False, help="Test overlap with P2P comms."
     )
@@ -78,7 +76,7 @@ def _parse_args(argv=None, namespace=None):
         "--bulk-overlap",
         action="store_true",
         default=False,
-        help="Enable bulk AG or RS overlap for a tensor that is not involved in the GEMM compute."
+        help="Enable bulk AG or RS overlap for a tensor that is not involved in the GEMM compute.",
     )
     parser.add_argument(
         "--check-numerics",
@@ -111,7 +109,7 @@ def _parse_args(argv=None, namespace=None):
         "--tcp-init",
         action="store_true",
         default=False,
-        help="Initialize torch.distributed with TcpStore."
+        help="Initialize torch.distributed with TcpStore.",
     )
     parser.add_argument(
         "--init-method", type=str, default=None, help="Set the torch.distributed init method."
@@ -120,7 +118,9 @@ def _parse_args(argv=None, namespace=None):
         "--bind-to-device",
         action="store_true",
         default=False,
-        help="Initialize torch.distributed with 'device_id' argument to bind each rank to 1 device."
+        help=(
+            "Initialize torch.distributed with 'device_id' argument to bind each rank to 1 device."
+        ),
     )
     parser.add_argument(
         "--bootstrap-backend",
@@ -132,7 +132,11 @@ def _parse_args(argv=None, namespace=None):
             + "initialization."
         ),
     )
-    parser.add_argument("--use-cuda-graphs", action="store_true", default=False,)
+    parser.add_argument(
+        "--use-cuda-graphs",
+        action="store_true",
+        default=False,
+    )
     parser.add_argument("-v", "--verbose", action="store_true", default=False)
     opts = parser.parse_args(argv, namespace)

@@ -159,6 +163,7 @@ def _parse_args(argv=None, namespace=None):

     return opts

+
 @record
 def _main(opts):
     if "OMPI_COMM_WORLD_SIZE" in os.environ:
@@ -258,8 +263,12 @@ def dist_print(msg, src=None, info=False, section=False, group=None):
     elif opts.bootstrap_backend == "mpi":
         assert dist.is_mpi_available()
     bootstrap_pg = dist.new_group(backend=opts.bootstrap_backend)
-    dist_print(f"Bootstrapping comm+GEMM overlap with backend=\"{opts.bootstrap_backend}\"", src=0,
-               info=True, section=True)
+    dist_print(
+        f'Bootstrapping comm+GEMM overlap with backend="{opts.bootstrap_backend}"',
+        src=0,
+        info=True,
+        section=True,
+    )

     # torch.distributed callback wrappers for bootstrapping userbuffers
     def allgather_callback(global_data: torch.Tensor, local_data: torch.Tensor, group: str):
@@ -444,8 +453,7 @@ def barrier_callback(group: str):
             inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0]
             if ub_obj2 is not None:
                 ker2_g = te.distributed.gather_along_first_dim(
-                    torch.transpose(kernel2_t, 0, 1),
-                    tp_group
+                    torch.transpose(kernel2_t, 0, 1), tp_group
                 )[0]
         else:
             # RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N)
@@ -462,7 +470,7 @@ def barrier_callback(group: str):
             ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0]
         else:
             # First all-gather all the bulk inputs into a list
-            bulk_inp_list = [ torch.zeros_like(bulk_inp) for _ in range(LOCAL_SIZE) ]
+            bulk_inp_list = [torch.zeros_like(bulk_inp) for _ in range(LOCAL_SIZE)]
             dist.all_gather(bulk_inp_list, bulk_inp, tp_group)
             # Sum the list together for final global result
             ref_g = torch.stack(bulk_inp_list).sum(dim=0)
@@ -482,8 +490,7 @@ def barrier_callback(group: str):
         fp8_dtype = tex.DType.kFloat8E4M3
         fp8_meta = tex.FP8TensorMeta()
         num_gemms = 6 if ub_obj2 is not None else 3
-        fp8_meta.amax_history = torch.zeros((2, num_gemms),
-                                            dtype=torch.float, device="cuda")
+        fp8_meta.amax_history = torch.zeros((2, num_gemms), dtype=torch.float, device="cuda")
         fp8_meta.scale = torch.ones(num_gemms, dtype=torch.float, device="cuda")
         fp8_meta.scale_inv = torch.ones(num_gemms, dtype=torch.float, device="cuda")

@@ -537,7 +544,7 @@ def barrier_callback(group: str):
                 kernel2_t.to(dtype=torch.float32),
                 kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT],
                 rtol=0.125,
-                atol=0.0675
+                atol=0.0675,
             )

         # Set Fp8 scales for userbuffers
@@ -624,7 +631,7 @@ def _fp8_gemm2(gemm1_out):
             ub_algo=tex.NVTE_Comm_Overlap_Algo.ATOMIC_GEMM_RS_P2P,
             ub=ub_obj2,
             extra_output_tensor=rs_out2,
-            out=ubuf_out2
+            out=ubuf_out2,
         )

     def _gemm():
@@ -743,8 +750,12 @@ def _gemm():
             # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K)
             output = all_outputs[0]
             test_out = torch.transpose(
-                te.distributed.gather_along_first_dim(torch.transpose(output, 0, 1),
-                                                      tp_group)[0], 0, 1)
+                te.distributed.gather_along_first_dim(
+                    torch.transpose(output, 0, 1), tp_group
+                )[0],
+                0,
+                1,
+            )
         else:
             # RS Output: (M/P, N) -> gather -> (M, N)
             output = rs_out
@@ -774,29 +785,25 @@ def _gemm():
         ref_out = ref2_g if ub_obj2 is not None else ref_g
         ref_nonzeros = torch.count_nonzero(ref_out)
         nonzero_info = (
-            f"output nonzeros = {test_nonzeros} "
-            + f"| reference count = {ref_nonzeros}"
+            f"output nonzeros = {test_nonzeros} " + f"| reference count = {ref_nonzeros}"
         )
         dist_print(nonzero_info, src=0, section=True)

         sizing_info = (
-            f"input: {list(inp.shape)} "
-            + f"| GEMM1 weights: {list(kernel_t.shape)[::-1]} "
+            f"input: {list(inp.shape)} " + f"| GEMM1 weights: {list(kernel_t.shape)[::-1]} "
         )
         if ub_obj2 is not None:
             sizing_info += f"| GEMM2 weights: {list(kernel2_t.shape)[::-1]} "
         sizing_info += f"| output: {list(output.shape)}\n"
         dist_print(sizing_info, section=True)

         sizing_info_g = (
-            f"input: {list(inp_g.shape)} "
-            + f"| GEMM1 weights: {list(ker_g.shape)} "
+            f"input: {list(inp_g.shape)} " + f"| GEMM1 weights: {list(ker_g.shape)} "
         )
         if ub_obj2 is not None:
             sizing_info_g += f"| GEMM2 weights: {list(ker2_g.shape)} "
         sizing_info_g += (
-            f"| output: {list(test_out.shape)} "
-            + f"| reference: {list(ref_out.shape)}\n"
+            f"| output: {list(test_out.shape)} " + f"| reference: {list(ref_out.shape)}\n"
         )
         dist_print(sizing_info_g, src=0)

14 changes: 5 additions & 9 deletions tests/pytorch/distributed/run_transformer_layer_with_overlap.py
@@ -71,6 +71,7 @@ def parse_args(argv=None, namespace=None):
     )
     return parser.parse_args(argv, namespace)

+
 @record
 def train(opts):
     WORLD_RANK = int(os.getenv("RANK"))
@@ -133,12 +134,7 @@ def dist_print(msg, src=None, debug=False, section=False):
         "ub_tp_comm_overlap": True,
         "parallel_attention_mlp": False,
     }
-    te_gpt = te.TransformerLayer(
-        hidden_size,
-        4 * hidden_size,
-        opts.num_heads,
-        **te_kwargs
-    )
+    te_gpt = te.TransformerLayer(hidden_size, 4 * hidden_size, opts.num_heads, **te_kwargs)

     # Create new TransformerLayer without comm overlap
     te_kwargs["ub_tp_comm_overlap"] = False
@@ -244,11 +240,11 @@ def dist_print(msg, src=None, debug=False, section=False):
     dist.all_reduce(numerics_failed_tensor, dist.ReduceOp.MAX)
     numerics_failed = bool(numerics_failed_tensor[0].item())
     if not numerics_failed:
-        max_diff_all = [ None for _ in range(WORLD_SIZE) ]
+        max_diff_all = [None for _ in range(WORLD_SIZE)]
         dist.all_gather_object(max_diff_all, max_diff)
-        max_diff_idx_all = [ None for _ in range(WORLD_SIZE) ]
+        max_diff_idx_all = [None for _ in range(WORLD_SIZE)]
         dist.all_gather_object(max_diff_idx_all, max_diff_idx)
-        max_diff_name_all = [ None for _ in range(WORLD_SIZE) ]
+        max_diff_name_all = [None for _ in range(WORLD_SIZE)]
         dist.all_gather_object(max_diff_name_all, max_diff_name)
         max_diff = max(max_diff_all)
         diff_idx = max_diff_all.index(max_diff)