Feat(QA norm):check norm weights for different ranks #62

Merged (1 commit) on Feb 27, 2024
149 changes: 149 additions & 0 deletions tests/common_fixture.py
@@ -0,0 +1,149 @@
import os
import random
import socket

import numpy as np
import torch

import internlm
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import Config
from internlm.data.utils import unpack_data
from internlm.initialize.launch import args_sanity_check

config = Config(
    dict(
        parallel=dict(
            zero1=dict(size=-1),
            tensor=dict(size=1, mode="mtp"),
            pipeline=dict(size=1, interleaved_overlap=True),
            weight=dict(size=1, overlap=True, memory_pool=True),
        ),
        data=dict(
            seq_len=2048,
            micro_num=4,
            micro_bsz=2,
            pack_sample_into_one=False,
            min_length=50,
            total_steps=10,
            valid_micro_num=4,
            valid_every=300,
            rampup_batch_size=None,
            diag_outlier_ratio=1.1,
            train_folder=None,
            valid_folder=None,
        ),
        model=dict(
            checkpoint=False,
            num_attention_heads=32,
            embed_split_hidden=True,
            vocab_size=103168,
            embed_grad_scale=1,
            parallel_output=True,
            hidden_size=4096,
            num_layers=32,
            mlp_ratio=8 / 3,
            apply_post_layer_norm=False,
            dtype="torch.bfloat16",
            norm_type="rmsnorm",
            layer_norm_epsilon=1e-5,
            use_flash_attn=True,
            num_chunks=1,
        ),
        model_type="INTERNLM",
        alert_address=None,
        monitor=dict(alert=dict(enable_feishu_alert=False, feishu_alert_address=None, light_monitor_address=None)),
        grad_scaler=dict(
            fp16=dict(
                initial_scale=2**16,
                min_scale=1,
                growth_interval=1000,
            ),
            growth_factor=2,
            backoff_factor=0.5,
            max_scale=2**24,
            hysteresis=2,
        ),
        adam=dict(
            lr=1e-4,
            adam_beta1=0.9,
            adam_beta2=0.95,
            adam_beta2_c=0,
            adam_eps=1e-8,
            weight_decay=0.01,
        ),
        hybrid_zero_optimizer=dict(
            overlap_sync_grad=True,
            overlap_sync_param=False,
            reduce_bucket_size=512 * 1024 * 1024,
            clip_grad_norm=1.0,
        ),
        beta2_scheduler=dict(
            init_beta2=0.95,
            c=0,
            cur_iter=-1,
        ),
        lr_scheduler=dict(
            total_steps=10,
            init_steps=0,
            warmup_ratio=0.01,
            eta_min=1e-5,
            last_epoch=-1,
        ),
        ckpt=dict(
            enable_save_ckpt=False,
            auto_resume=False,
        ),
        loss=dict(
            label_smoothing=0,
        ),
    )
)


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def build_environment(rank, world_size, free_port, config):
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(free_port)
torch.cuda.empty_cache()
# launcher="torch"
internlm.launch_from_torch(config=config, seed=1024)
args_sanity_check()


def seed_all(seed, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    if cuda_deterministic:  # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


def load_new_batch(train_dl, train_iter):
    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_dl)
        batch = next(train_iter)

    if batch[0].get("type_ids", None) is not None:
        # if use_flash_attn is False, we need to unpack type_ids
        if not gpc.config.model.use_flash_attn:
            batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True)

    return batch, train_iter
205 changes: 205 additions & 0 deletions tests/test_training/test_norm_weight.py
@@ -0,0 +1,205 @@
import gc
import multiprocessing as mp
import os

import pytest
import torch

import internlm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.train import (
    get_scheduler_hooks,
    get_train_data_loader,
    initialize_isp_communicator,
    initialize_model,
    initialize_optimizer,
)
from internlm.utils.logger import get_logger
from tests.common_fixture import (
    build_environment,
    config,
    find_free_port,
    load_new_batch,
    seed_all,
)

logger = get_logger(__file__)


def compute_rotol(tensor1, tensor2):
    # Log the largest absolute element-wise difference and the corresponding relative difference.
    torch.set_printoptions(precision=10)
    max_diff, index_max_diff = (tensor1 - tensor2).abs().max(dim=0)
    max_diff = max_diff.item()
    index_max_diff = index_max_diff.item()
    rtol = max_diff / abs(tensor2[index_max_diff])
    logger.info(
        f"The max diff between two tensors is {max_diff}, which is the diff between element "
        f"{tensor1[index_max_diff]} and {tensor2[index_max_diff]}. The relative diff is {rtol}."
    )


def check_norm_pos(name, norm_list):
    # Compare the per-block norm weights gathered from rank i with those from rank i + 1, across all 8 ranks.
    for i in range(7):
        for j in range(len(norm_list[i])):
            if not torch.equal(norm_list[i][j], norm_list[i + 1][j]):
                compute_rotol(norm_list[i][j], norm_list[i + 1][j])
                assert False, f"The {name} weights of block between different ranks are not equal."


def train_check_norm_weight(args):
    # init
    rank, world_size, free_port, sp = args
    total_steps = 2000
    share_data_path = os.environ["share_data_path"]
    config.data.total_steps = total_steps
    config.lr_scheduler.total_steps = total_steps
    config.parallel.tensor = dict(size=2, mode=f"{sp}")
    if sp == "isp":
        config.parallel.weight = dict(size=4, overlap=True, memory_pool=True)
    config.data.train_folder = os.path.join(share_data_path, "quality_assurance/0715_data/train")

    build_environment(rank, world_size, free_port, config)

    # set seed
    seed_all(1024)

    # initialize model
    model = initialize_model()

    # initialize isp communicator
    isp_communicator = initialize_isp_communicator(model)

    # initialize loss function
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing)

    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator)

    train_dl, dataset_types = get_train_data_loader(num_worker=0)

    metric = AccPerplex(
        device=torch.cuda.current_device(),
        tp_pg=gpc.get_group(ParallelMode.TENSOR),
        dp_pg=gpc.get_group(ParallelMode.DATA),
        dataset_types=dataset_types,
    )

    trainer, train_dl, _, _ = internlm.initialize_trainer(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dl,
        lr_scheduler=lr_scheduler,
        beta2_scheduler=beta2_scheduler,
        scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator),
    )

    # set the engine to training mode
    trainer.train()

    # transfer the train data loader into a train data iterator
    train_iter = iter(train_dl)

    for batch_count in range(total_steps):
        if batch_count % 100 == 0:
            torch.cuda.empty_cache()
            gc.collect()

        # load batch data
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter)

        # process data
        if batch[0].get("type_ids", None) is not None:
            metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None))

        # run the forward and backward passes for this step
        _, _, _ = trainer.execute_schedule(
            batch,
            forward_only=False,
            return_loss=True,
            return_output_label=False,
        )

        if isp_communicator and isp_communicator.enable_memory_pool:
            isp_communicator.memory_pool.reset_lazy_pools()

        trainer.step()

        torch.cuda.reset_peak_memory_stats()

    # collect the per-block norm weights (and the final model norm, if present) on CPU
    blocks_norm1_list = []
    blocks_norm2_list = []

    for block in model.model.blocks:
        blocks_norm1_list.append(block.norm1.weight.detach().to("cpu"))
        blocks_norm2_list.append(block.norm2.weight.detach().to("cpu"))
    if hasattr(model.model, "norm"):
        model_norm = model.model.norm.weight.detach().to("cpu")
    else:
        model_norm = None

    return blocks_norm1_list, blocks_norm2_list, model_norm


def check_result(result):
    norm1_ranks = []
    norm2_ranks = []
    model_norm_ranks = []
    for rank in range(8):
        norm1_ranks.append(result[rank][0])
        norm2_ranks.append(result[rank][1])
        if result[rank][2] is not None:
            model_norm_ranks.append(result[rank][2])

    check_norm_pos("norm1", norm1_ranks)
    check_norm_pos("norm2", norm2_ranks)
    for i in range(len(model_norm_ranks) - 1):
        if not torch.equal(model_norm_ranks[i], model_norm_ranks[i + 1]):
            compute_rotol(model_norm_ranks[i], model_norm_ranks[i + 1])
            assert False, "The norm weights of model between different ranks are not equal."


@pytest.mark.check_norm_msp
def test_check_norm_msp():
    free_port = find_free_port()
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        result = pool.map(
            train_check_norm_weight,
            [[rank, 8, free_port, "msp"] for rank in range(8)],
        )
        pool.close()
        pool.join()

    check_result(result)


@pytest.mark.check_norm_fsp
def test_check_norm_fsp():
    free_port = find_free_port()
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        result = pool.map(
            train_check_norm_weight,
            [[rank, 8, free_port, "fsp"] for rank in range(8)],
        )
        pool.close()
        pool.join()

    check_result(result)


@pytest.mark.check_norm_isp
def test_check_norm_isp():
    free_port = find_free_port()
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        result = pool.map(
            train_check_norm_weight,
            [[rank, 8, free_port, "isp"] for rank in range(8)],
        )
        pool.close()
        pool.join()

    check_result(result)
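
Note (added for context, not part of this PR): check_norm_msp, check_norm_fsp, and check_norm_isp are custom pytest markers. Assuming they are not already declared in the repository's pytest configuration, a minimal conftest.py sketch like the one below would register them so that, for example, pytest -m check_norm_msp tests/test_training/test_norm_weight.py selects a single tensor-parallel mode without unknown-marker warnings; the file name and marker descriptions here are illustrative assumptions.

# conftest.py -- hypothetical sketch, not included in this PR
def pytest_configure(config):
    # Register the custom markers used by test_norm_weight.py so that
    # `pytest -m <marker>` can select one tensor-parallel mode at a time.
    for marker in ("check_norm_msp", "check_norm_fsp", "check_norm_isp"):
        config.addinivalue_line("markers", f"{marker}: 8-rank norm-weight consistency check")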