Fix QA port: add free_port (#30)
li126com authored Jan 26, 2024
1 parent 4b198c2 commit 3a2c556
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions tests/test_training/test_load_ckpt_loss.py
@@ -2,6 +2,7 @@
 import os
 import random
 import shutil
+import socket
 
 import numpy as np
 import pytest
@@ -127,12 +128,18 @@
 )
 
 
-def build_environment(rank, world_size, config):
+def find_free_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+def build_environment(rank, world_size, free_port, config):
     os.environ["RANK"] = str(rank)
     os.environ["LOCAL_RANK"] = str(rank)
     os.environ["WORLD_SIZE"] = str(world_size)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = "33333"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(free_port)
     torch.cuda.empty_cache()
     # launcher="torch"
     internlm.launch_from_torch(config=config, seed=1024)
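
For reference, a minimal standalone sketch of the free-port pattern this hunk introduces. The __main__ block and the print are illustrative only; in the actual test the environment variables are set inside build_environment before internlm.launch_from_torch consumes them.

import os
import socket


def find_free_port():
    # Bind to port 0 so the OS picks an unused ephemeral port, then report it.
    # Small caveat: the port is released when the socket closes and could, in
    # principle, be grabbed by another process before the trainer re-binds it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


if __name__ == "__main__":
    port = find_free_port()
    # Mirrors build_environment(): every rank must see the same rendezvous address.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    print("rendezvous on", os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])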
@@ -156,8 +163,8 @@ def seed_all(seed, cuda_deterministic=False):
 
 def train_model(args):
     # init
-    rank, world_size, train_round = args
-    build_environment(rank, world_size, config)
+    rank, world_size, train_round, free_port = args
+    build_environment(rank, world_size, free_port, config)
     total_steps = 6
 
     if train_round == 1:
@@ -286,12 +293,13 @@ def train_model(args):
 
 def test_loss():
     results = []
+    free_port = find_free_port()
     ctx = mp.get_context("spawn")
     for train_round in range(2):
        with ctx.Pool(processes=8) as pool:
            result = pool.map(
                train_model,
-                [[rank, 8, train_round] for rank in range(8)],
+                [[rank, 8, train_round, free_port] for rank in range(8)],
            )
            results.append(result)
            pool.close()
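
A minimal sketch of why the port is discovered once in the parent process and fanned out to every rank. The worker below is hypothetical (the real test maps train_model over 8 GPU ranks): the point is that choosing the port inside each worker would let ranks disagree, while the hard-coded 33333 this commit replaces could collide with a port already in use.

import multiprocessing as mp
import os
import socket


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def worker(args):
    # Hypothetical stand-in for train_model: just records the rendezvous it was given.
    rank, world_size, free_port = args
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(free_port)
    return f"rank {rank}/{world_size} -> port {free_port}"


if __name__ == "__main__":
    free_port = find_free_port()  # chosen once, shared by all ranks
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=4) as pool:
        print(pool.map(worker, [[rank, 4, free_port] for rank in range(4)]))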
