Skip to content

Commit

Permalink
[ROCm]: Add support to continue on fail
Browse files: browse the repository at this point in the history
  • Loading branch information
Rahul Batra committed Feb 16, 2024
1 parent 8a11b40 commit f6d3162
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
4 changes: 2 additions & 2 deletions build/rocm/run_multi_gpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ run_tests() {
local base_dir=./logs
local gpu_devices="$1"
export HIP_VISIBLE_DEVICES=$gpu_devices
python3 -m pytest --html=$base_dir/multi_gpu_pmap_test_log.html --reruns 3 -x tests/pmap_test.py
python3 -m pytest --html=$base_dir/multi_gpu_multi_device_test_log.html --reruns 3 -x tests/multi_device_test.py
python3 -m pytest --html=$base_dir/multi_gpu_pmap_test_log.html --reruns 3 tests/pmap_test.py
python3 -m pytest --html=$base_dir/multi_gpu_multi_device_test_log.html --reruns 3 tests/multi_device_test.py
python3 -m pytest_html_merger -i $base_dir/ -o $base_dir/final_compiled_report.html
}

Expand Down
30 changes: 21 additions & 9 deletions build/rocm/run_single_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ def collect_testmodules():
print("Test module discovery failed.")
exit(return_code)
for line in stdout.split("\n"):
match = re.match("<Module (.*)>", line)
match = re.match("<Module (.*)>", line.strip())
if match:
test_file = match.group(1)
if "/" not in test_file:
test_file = os.path.join("tests",test_file)
all_test_files.append(test_file)
print("---------- collected test modules ----------")
print("Found %d test modules." % (len(all_test_files)))
Expand All @@ -79,7 +81,7 @@ def collect_testmodules():
return all_test_files


def run_test(testmodule, gpu_tokens):
def run_test(testmodule, gpu_tokens, continue_on_fail):
global LAST_CODE
with GPU_LOCK:
if LAST_CODE != 0:
Expand All @@ -90,39 +92,43 @@ def run_test(testmodule, gpu_tokens):
"XLA_PYTHON_CLIENT_ALLOCATOR": "default",
}
testfile = extract_filename(testmodule)
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", testmodule]
if continue_on_fail:
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-v", testmodule]
else:
cmd = ["python3", "-m", "pytest", '--html={}/{}_log.html'.format(base_dir, testfile), "--reruns", "3", "-x", "-v", testmodule]
return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
with GPU_LOCK:
gpu_tokens.append(target_gpu)
if LAST_CODE == 0:
print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
print(stdout)
print(stderr)
LAST_CODE = return_code
if continue_on_fail == False:
LAST_CODE = return_code
return


def run_parallel(all_testmodules, p):
print("Running tests with parallelism=", p)
def run_parallel(all_testmodules, p, c):
print(f"Running tests with parallelism=", p)
available_gpu_tokens = list(range(p))
executor = ThreadPoolExecutor(max_workers=p)
# walking through test modules
for testmodule in all_testmodules:
executor.submit(run_test, testmodule, available_gpu_tokens)
executor.submit(run_test, testmodule, available_gpu_tokens, c)
# waiting for all modules to finish
executor.shutdown(wait=True) # wait for all jobs to finish
return


def find_num_gpus():
cmd = ["lspci|grep 'controller'|grep 'AMD/ATI'|wc -l"]
cmd = ["lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
_, _, stdout = run_shell_command(cmd, shell=True)
return int(stdout)


def main(args):
all_testmodules = collect_testmodules()
run_parallel(all_testmodules, args.parallel)
run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
generate_final_report()
exit(LAST_CODE)

Expand All @@ -134,7 +140,13 @@ def main(args):
"--parallel",
type=int,
help="number of tests to run in parallel")
parser.add_argument("-c",
"--continue_on_fail",
action='store_true',
help="continue on failure")
args = parser.parse_args()
if args.continue_on_fail:
print("continue on fail is set")
if args.parallel is None:
sys_gpu_count = find_num_gpus()
args.parallel = sys_gpu_count
Expand Down

0 comments on commit f6d3162

Please sign in to comment.