diff --git a/README.md b/README.md
index 4170bab4..35930f6d 100644
--- a/README.md
+++ b/README.md
@@ -40,3 +40,30 @@ tensorboard --logdir=./torchtrain/outputs/tb
```
4. go to the URL it provides OR to http://localhost:6006/
+
+## Multi-Node Training
+For training on ParallelCluster/Slurm-type clusters, you can use the multinode_trainer.slurm file to submit your sbatch job.
+Note that you will need to adjust the number of nodes and the GPU count per node to match your cluster configuration.
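+For example, to submit the job from the repository root:
+```
+sbatch multinode_trainer.slurm
+```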
+To adjust the total node count, set both of the following to your total number of nodes:
+```
+#SBATCH --ntasks=2
+#SBATCH --nodes=2
+```
+Then update the srun launch parameters to match:
+```
+srun torchrun --nnodes 2
+```
+where `--nnodes` is your total node count, matching the sbatch node count above.
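+
+For example, for a hypothetical 4-node allocation, the sbatch directives and the srun launch would change to:
+```
+# example only: 4 nodes total
+#SBATCH --ntasks=4
+#SBATCH --nodes=4
+
+srun torchrun --nnodes 4 ...
+```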
+
+To adjust the GPU count per node: if your nodes do not have 8 GPUs each, change `--nproc_per_node` in the torchrun command and `#SBATCH --gpus-per-task` in the SBATCH directives.
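+
+For example, a minimal sketch for a hypothetical cluster with 4 GPUs per node:
+```
+# example only: 4 GPUs per node
+#SBATCH --gpus-per-task=4
+
+srun torchrun --nnodes 2 --nproc_per_node 4 ...
+```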
diff --git a/multinode_trainer.slurm b/multinode_trainer.slurm
new file mode 100644
index 00000000..296bb625
--- /dev/null
+++ b/multinode_trainer.slurm
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# --- This script is optimized for AWS with EFA.
+# --- Adjust NCCL_BUFFSIZE if you encounter memory
+# --- constraint issues or want to tune for better performance.
+# ---
+
+#SBATCH --job-name=torchtrain_multi_node
+
+#SBATCH --ntasks=2
+
+#SBATCH --nodes=2
+
+#SBATCH --gpus-per-task=8
+
+#SBATCH --cpus-per-task=96
+
+#SBATCH --partition=train
+
+
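+# get the list of allocated nodes; the first one acts as the head node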
+nodes=( $( scontrol show hostnames "$SLURM_JOB_NODELIST" ) )
+head_node=${nodes[0]}
+head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
+
+echo "Node IP: $head_node_ip"
+export LOGLEVEL=INFO
+# Enable for A100
+export FI_PROVIDER="efa"
+# P2P should stay enabled; uncomment the next line only if you need to disable it
+# export NCCL_P2P_DISABLE=1
+export NCCL_IB_DISABLE=1
+
+# debugging flags (optional)
+export NCCL_DEBUG=WARN
+export PYTHONFAULTHANDLER=1
+# optional debug settings
+# export NCCL_DEBUG=INFO
+# export NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV
+
+export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH
+export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH
+export CUDA_LAUNCH_BLOCKING=0
+
+# on your cluster you might need these:
+# set the network interface
+export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
+export NCCL_BUFFSIZE=2097152
+#export TORCH_DIST_INIT_BARRIER=1
+export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
+
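+# pause DCGM profiling metrics collection for the duration of the training run (resumed below)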
+dcgmi profile --pause
+# adjust sbatch --ntasks and sbatch --nodes above and --nnodes below
+# to your specific node count, and update target launch file.
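+# torchrun uses the head node IP as the c10d rendezvous endpoint (port 29500)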
+srun torchrun --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./train.py --steps 10
+dcgmi profile --resume
diff --git a/torchtrain/profiling.py b/torchtrain/profiling.py
index 0a128c46..167067e0 100644
--- a/torchtrain/profiling.py
+++ b/torchtrain/profiling.py
@@ -1,8 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import contextlib
import os
+
import torch
try:
@@ -47,7 +48,7 @@ def trace_handler(prof):
curr_trace_dir_name = "iteration_" + str(_global_iter_count)
curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
if not os.path.exists(curr_trace_dir):
- os.makedirs(curr_trace_dir)
+ os.makedirs(curr_trace_dir, exist_ok=True)
rank0_log(f"exporting profile traces to {curr_trace_dir}")
prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
@@ -55,7 +56,7 @@ def trace_handler(prof):
rank0_log(f"Profiling active. Traces will be saved at {trace_dir}")
if not os.path.exists(trace_dir):
- os.makedirs(trace_dir)
+ os.makedirs(trace_dir, exist_ok=True)
with torch.profiler.profile(
activities=[