1
1
"""Util constants/functions for the backends."""
2
+ import colorama
2
3
import datetime
3
4
import enum
4
5
import getpass
5
6
import os
6
7
import pathlib
7
8
import shlex
8
9
import subprocess
10
+ import sys
9
11
import textwrap
10
12
import time
11
13
from typing import Dict , List , Optional , Tuple , Union
@@ -598,12 +600,12 @@ def run_command_on_ip_via_ssh(
598
600
ssh_private_key : str ,
599
601
port_forward : Optional [List [int ]] = None ,
600
602
# Advanced options.
603
+ require_outputs : bool = False ,
601
604
log_path : str = '/dev/null' ,
602
605
stream_logs : bool = True ,
603
- check : bool = False ,
604
606
ssh_mode : SshMode = SshMode .NON_INTERACTIVE ,
605
607
ssh_control_name : Optional [str ] = None ,
606
- ) -> Tuple [subprocess . Popen , str , str ]:
608
+ ) -> Union [ int , Tuple [int , str , str ] ]:
607
609
"""Uses 'ssh' to run 'cmd' on a node with ip.
608
610
609
611
Args:
@@ -616,6 +618,7 @@ def run_command_on_ip_via_ssh(
616
618
617
619
Advanced options:
618
620
621
+ require_outputs: Whether to return the stdout/stderr of the command.
619
622
log_path: Redirect stdout/stderr to the log_path.
620
623
stream_logs: Stream logs to the stdout/stderr.
621
624
check: Check the success of the command.
@@ -625,7 +628,9 @@ def run_command_on_ip_via_ssh(
625
628
for optimizing the ssh speed.
626
629
627
630
Returns:
628
- A tuple of (process, stdout, stderr).
631
+ returncode
632
+ or
633
+ A tuple of (returncode, stdout, stderr).
629
634
"""
630
635
base_ssh_command = _ssh_base_command (ip ,
631
636
ssh_private_key ,
@@ -636,8 +641,8 @@ def run_command_on_ip_via_ssh(
636
641
if ssh_mode == SshMode .LOGIN :
637
642
assert isinstance (cmd , list ), 'cmd must be a list for login mode.'
638
643
command = base_ssh_command + cmd
639
- proc = run (command , shell = False , check = check )
640
- return proc , '' , ''
644
+ proc = run (command , shell = False , check = False )
645
+ return proc . returncode , '' , ''
641
646
if isinstance (cmd , list ):
642
647
cmd = ' ' .join (cmd )
643
648
# We need this to correctly run the cmd, and get the output.
@@ -652,7 +657,31 @@ def run_command_on_ip_via_ssh(
652
657
shlex .quote (f'true && source ~/.bashrc && export OMP_NUM_THREADS=1 '
653
658
f'PYTHONWARNINGS=ignore && ({ cmd } )' ),
654
659
]
655
- return log_lib .run_with_log (command , log_path , stream_logs , check = check )
660
+ return log_lib .run_with_log (command ,
661
+ log_path ,
662
+ stream_logs ,
663
+ require_outputs = require_outputs )
664
+
665
+
666
def handle_returncode(returncode: int,
                      command: str,
                      error_msg: str,
                      stderr: Optional[str] = None) -> None:
    """Exit the process when a command finished with a non-zero returncode.

    A returncode of 0 is a no-op. For any other value the failure is
    reported through the module-level logger and sys.exit(returncode) is
    called.

    Args:
        returncode: The returncode of the command.
        command: The command that was run; included in the failure log line.
        error_msg: The error message to log, wrapped in colorama red.
        stderr: The stderr of the command; logged first when it is not None.
    """
    # Guard clause: a successful command needs no handling.
    if returncode == 0:
        return

    # Emit the captured stderr ahead of the summary lines.
    if stderr is not None:
        logger.error(stderr)

    failure_summary = f'Command failed with code {returncode} : {command} '
    logger.error(failure_summary)

    highlighted_msg = (
        f'{colorama.Fore.RED} {error_msg} {colorama.Style.RESET_ALL} ')
    logger.error(highlighted_msg)

    # Propagate the command's returncode as this process's exit status.
    sys.exit(returncode)
656
685
657
686
658
687
def run (cmd , ** kwargs ):
0 commit comments