diff --git a/sky/authentication.py b/sky/authentication.py index aa9f336c27a..53fc27fd1c1 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -37,6 +37,7 @@ from sky import clouds from sky import sky_logging +from sky import skypilot_config from sky.adaptors import gcp from sky.adaptors import ibm from sky.skylet.providers.lambda_cloud import lambda_utils @@ -378,8 +379,13 @@ def setup_scp_authentication(config: Dict[str, Any]) -> Dict[str, Any]: def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: + # By default, the ssh session is established with kubectl port-forwarding + # through a ClusterIP service. + nodeport_mode = kubernetes_utils.KubernetesNetworkingMode.NODEPORT + port_forward_mode = kubernetes_utils.KubernetesNetworkingMode.PORT_FORWARD + ssh_setup_mode = skypilot_config.get_nested(('kubernetes', 'networking'), + port_forward_mode.value) get_or_generate_keys() - # Run kubectl command to add the public key to the cluster. public_key_path = os.path.expanduser(PUBLIC_SSH_KEY_PATH) key_label = clouds.Kubernetes.SKY_SSH_KEY_SECRET_NAME @@ -404,16 +410,36 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: logger.error(suffix) raise + ssh_jump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME + if ssh_setup_mode.lower() == nodeport_mode.value.lower(): + network_mode = nodeport_mode + service_type = kubernetes_utils.KubernetesServiceType.NODEPORT + + elif ssh_setup_mode.lower() == port_forward_mode.value.lower(): + kubernetes_utils.check_port_forward_mode_dependencies() + network_mode = port_forward_mode + # Using `kubectl port-forward` creates a direct tunnel to the jump pod + # and does not require opening any ports on Kubernetes nodes. As a + # result, the service can be a simple ClusterIP service, which we + # access with `kubectl port-forward`. + service_type = kubernetes_utils.KubernetesServiceType.CLUSTERIP + else: + raise ValueError(f'Unsupported kubernetes networking mode: ' + f'{ssh_setup_mode}. The mode has to be either ' + f'\'{port_forward_mode.value}\' or ' + f'\'{nodeport_mode.value}\'. ' + 'Please check: ~/.sky/config.yaml') # Setup service for SSH jump pod. We create the SSH jump service here # because we need to know the service IP address and port to set the # ssh_proxy_command in the autoscaler config. 
namespace = kubernetes_utils.get_current_kube_config_context_namespace() - sshjump_name = clouds.Kubernetes.SKY_SSH_JUMP_NAME - - kubernetes_utils.setup_sshjump_svc(sshjump_name, namespace) + kubernetes_utils.setup_sshjump_svc(ssh_jump_name, namespace, service_type) ssh_proxy_cmd = kubernetes_utils.get_ssh_proxy_command( - PRIVATE_SSH_KEY_PATH, sshjump_name, namespace) + PRIVATE_SSH_KEY_PATH, ssh_jump_name, network_mode, namespace, + clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_PATH, + clouds.Kubernetes.PORT_FORWARD_PROXY_CMD_TEMPLATE) config['auth']['ssh_proxy_command'] = ssh_proxy_cmd + return config diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index 1d793496cc4..f41c45e545f 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -1355,7 +1355,7 @@ def wait_until_ray_cluster_ready( def ssh_credential_from_yaml(cluster_yaml: str, docker_user: Optional[str] = None - ) -> Dict[str, str]: + ) -> Dict[str, Any]: """Returns ssh_user, ssh_private_key and ssh_control name.""" config = common_utils.read_yaml(cluster_yaml) auth_section = config['auth'] @@ -1371,6 +1371,10 @@ def ssh_credential_from_yaml(cluster_yaml: str, } if docker_user is not None: credentials['docker_user'] = docker_user + ssh_provider_module = config['provider']['module'] + # Disable ssh ControlMaster when the cluster is launched by the Kubernetes + # provider (see command_runner.ssh_options_list for the rationale). + if 'kubernetes' in ssh_provider_module: + credentials['disable_control_master'] = True return credentials diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 30377f26881..0a034261dbf 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2988,37 +2988,12 @@ def _sync_file_mounts( self._execute_file_mounts(handle, all_file_mounts) self._execute_storage_mounts(handle, storage_mounts) - def _update_envs_for_k8s(self, handle: CloudVmRayResourceHandle, - task: task_lib.Task) -> None: - """Update envs with env vars from Kubernetes if cloud is Kubernetes. - - Kubernetes automatically populates containers with critical environment - variables, such as those for discovering services running in the - cluster and CUDA/nvidia environment variables. We need to update task - environment variables with these env vars. This is needed for GPU - support and service discovery. - - See https://github.com/skypilot-org/skypilot/issues/2287 for - more details. - """ - if isinstance(handle.launched_resources.cloud, clouds.Kubernetes): - temp_envs = copy.deepcopy(task.envs) - cloud_env_vars = handle.launched_resources.cloud.query_env_vars( - handle.cluster_name_on_cloud) - task.update_envs(cloud_env_vars) - - # Re update the envs with the original envs to give priority to - # the original envs. - task.update_envs(temp_envs) - def _setup(self, handle: CloudVmRayResourceHandle, task: task_lib.Task, detach_setup: bool) -> None: start = time.time() style = colorama.Style fore = colorama.Fore - self._update_envs_for_k8s(handle, task) - if task.setup is None: return @@ -3327,7 +3302,6 @@ def _execute( # Check the task resources vs the cluster resources. 
Since `sky exec` # will not run the provision and _check_existing_cluster self.check_resources_fit_cluster(handle, task) - self._update_envs_for_k8s(handle, task) resources_str = backend_utils.get_task_resources_str(task) diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index 9855b5410ac..0963215452d 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -19,7 +19,7 @@ logger = sky_logging.init_logger(__name__) -_CREDENTIAL_PATH = '~/.kube/config' +CREDENTIAL_PATH = '~/.kube/config' @clouds.CLOUD_REGISTRY.register @@ -28,7 +28,9 @@ class Kubernetes(clouds.Cloud): SKY_SSH_KEY_SECRET_NAME = f'sky-ssh-{common_utils.get_user_hash()}' SKY_SSH_JUMP_NAME = f'sky-sshjump-{common_utils.get_user_hash()}' - + PORT_FORWARD_PROXY_CMD_TEMPLATE = \ + 'kubernetes-port-forward-proxy-command.sh.j2' + PORT_FORWARD_PROXY_CMD_PATH = '~/.sky/port-forward-proxy-cmd.sh' # Timeout for resource provisioning. This timeout determines how long to # wait for pod to be in pending status before giving up. # Larger timeout may be required for autoscaling clusters, since autoscaler @@ -296,7 +298,7 @@ def _make(instance_list): @classmethod def check_credentials(cls) -> Tuple[bool, Optional[str]]: - if os.path.exists(os.path.expanduser(_CREDENTIAL_PATH)): + if os.path.exists(os.path.expanduser(CREDENTIAL_PATH)): # Test using python API try: return kubernetes_utils.check_credentials() @@ -305,10 +307,10 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: f'{common_utils.format_exception(e)}') else: return (False, 'Credentials not found - ' - f'check if {_CREDENTIAL_PATH} exists.') + f'check if {CREDENTIAL_PATH} exists.') def get_credential_file_mounts(self) -> Dict[str, str]: - return {_CREDENTIAL_PATH: _CREDENTIAL_PATH} + return {CREDENTIAL_PATH: CREDENTIAL_PATH} def instance_type_exists(self, instance_type: str) -> bool: return kubernetes_utils.KubernetesInstanceType.is_valid_instance_type( @@ -365,36 +367,3 @@ def query_status(cls, name: str, tag_filters: Dict[str, str], cluster_status.append(status_lib.ClusterStatus.INIT) # If pods are not found, we don't add them to the return list return cluster_status - - @classmethod - def query_env_vars(cls, name: str) -> Dict[str, str]: - namespace = kubernetes_utils.get_current_kube_config_context_namespace() - pod = kubernetes.core_api().list_namespaced_pod( - namespace, - label_selector=f'skypilot-cluster={name},ray-node-type=head' - ).items[0] - response = kubernetes.stream()( - kubernetes.core_api().connect_get_namespaced_pod_exec, - pod.metadata.name, - namespace, - command=['env'], - stderr=True, - stdin=False, - stdout=True, - tty=False, - _request_timeout=kubernetes.API_TIMEOUT) - # Split response by newline and filter lines containing '=' - raw_lines = response.split('\n') - filtered_lines = [line for line in raw_lines if '=' in line] - - # Split each line at the first '=' occurrence - lines = [line.split('=', 1) for line in filtered_lines] - - # Construct the dictionary using only valid environment variable names - env_vars = {} - for line in lines: - key = line[0] - if common_utils.is_valid_env_var(key): - env_vars[key] = line[1] - - return env_vars diff --git a/sky/skylet/providers/kubernetes/node_provider.py b/sky/skylet/providers/kubernetes/node_provider.py index 82fab8c9c02..8e44379d19e 100644 --- a/sky/skylet/providers/kubernetes/node_provider.py +++ b/sky/skylet/providers/kubernetes/node_provider.py @@ -214,10 +214,12 @@ def create_node(self, node_config, tags, count): 'Cluster may be out of resources or ' 'may be too slow 
to autoscale.') all_ready = True - + pods_and_containers_running = False + pods = [] for node in new_nodes: pod = kubernetes.core_api().read_namespaced_pod( node.metadata.name, self.namespace) + pods.append(pod) if pod.status.phase == 'Pending': # Iterate over each pod to check their status if pod.status.container_statuses is not None: @@ -237,10 +239,44 @@ def create_node(self, node_config, tags, count): # If container_statuses is None, then the pod hasn't # been scheduled yet. all_ready = False - if all_ready: + + # Check if all the pods and all containers within those pods are running. + if all(pod.status.phase == 'Running' for pod in pods) \ + and all(container.state.running for pod in pods for container in pod.status.container_statuses): + pods_and_containers_running = True + + if all_ready and pods_and_containers_running: break time.sleep(1) + # Kubernetes automatically populates containers with critical + # environment variables, such as those for discovering services running + # in the cluster and CUDA/nvidia environment variables. We need to + # update task environment variables with these env vars. This is needed + # for GPU support and service discovery. + # See https://github.com/skypilot-org/skypilot/issues/2287 for + # more details. + # Here, we capture the environment variables from the pod's runtime and + # write them to /etc/profile.d/, making them available for all users in + # future shell sessions. + set_k8s_env_var_cmd = [ + '/bin/sh', '-c', + ('printenv | awk -F "=" \'{print "export " $1 "=\\047" $2 "\\047"}\' > ~/k8s_env_var.sh;' 'mv ~/k8s_env_var.sh /etc/profile.d/k8s_env_var.sh || ' 'sudo mv ~/k8s_env_var.sh /etc/profile.d/k8s_env_var.sh') + ] + for new_node in new_nodes: + kubernetes.stream()( + kubernetes.core_api().connect_get_namespaced_pod_exec, + new_node.metadata.name, + self.namespace, + command=set_k8s_env_var_cmd, + stderr=True, + stdin=False, + stdout=True, + tty=False, + _request_timeout=kubernetes.API_TIMEOUT) + def terminate_node(self, node_id): logger.info(config.log_prefix + 'calling delete_namespaced_pod') try: diff --git a/sky/templates/kubernetes-port-forward-proxy-command.sh.j2 b/sky/templates/kubernetes-port-forward-proxy-command.sh.j2 new file mode 100644 index 00000000000..fa71df3a0ec --- /dev/null +++ b/sky/templates/kubernetes-port-forward-proxy-command.sh.j2 @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +set -uo pipefail + +# Checks if socat is installed +if ! command -v socat > /dev/null; then + echo "Using 'port-forward' mode to run ssh session on Kubernetes instances requires 'socat' to be installed. Please install 'socat'" >&2 + exit 1 +fi + +# Checks if lsof is installed +if ! command -v lsof > /dev/null; then + echo "Checking port availability for 'port-forward' mode requires 'lsof' to be installed. Please install 'lsof'" >&2 + exit 1 +fi + +# Function to check if port is in use +is_port_in_use() { + local port="$1" + lsof -i :${port} > /dev/null 2>&1 +} + +# Start from a fixed local port and increment if in use +local_port={{ local_port }} +while is_port_in_use "${local_port}"; do + local_port=$((local_port + 1)) +done + +# Establishes a connection between the local port and the ssh jump pod +kubectl port-forward svc/{{ ssh_jump_name }} "${local_port}":22 & + +# Terminate the port-forward process when this script exits. +K8S_PORT_FWD_PID=$! +trap "kill $K8S_PORT_FWD_PID" EXIT + +# Wait until the port-forward connection to 127.0.0.1:[local_port] is established +while ! 
nc -z 127.0.0.1 "${local_port}"; do + sleep 0.1 +done + +# Establishes two directional byte streams to handle stdin/stdout between +# the terminal and the jump pod. +# The socat process terminates when the port-forward terminates. +socat - tcp:127.0.0.1:"${local_port}" \ No newline at end of file diff --git a/sky/templates/kubernetes-sshjump.yml.j2 b/sky/templates/kubernetes-sshjump.yml.j2 index d2844034263..a0e353b08ca 100644 --- a/sky/templates/kubernetes-sshjump.yml.j2 +++ b/sky/templates/kubernetes-sshjump.yml.j2 @@ -50,14 +50,13 @@ service_spec: name: {{ name }} parent: skypilot spec: - type: NodePort + type: {{ service_type }} selector: component: {{ name }} ports: - protocol: TCP port: 22 targetPort: 22 - # The following ServiceAccount/Role/RoleBinding sets up an RBAC for life cycle # management of the jump pod/service service_account: diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index 08fde49354d..7969694632c 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -42,13 +42,16 @@ def _ssh_control_path(ssh_control_filename: Optional[str]) -> Optional[str]: return path -def ssh_options_list(ssh_private_key: Optional[str], - ssh_control_name: Optional[str], - *, - ssh_proxy_command: Optional[str] = None, - docker_ssh_proxy_command: Optional[str] = None, - timeout: int = 30, - port: int = 22) -> List[str]: +def ssh_options_list( + ssh_private_key: Optional[str], + ssh_control_name: Optional[str], + *, + ssh_proxy_command: Optional[str] = None, + docker_ssh_proxy_command: Optional[str] = None, + timeout: int = 30, + port: int = 22, + disable_control_master: Optional[bool] = False, +) -> List[str]: """Returns a list of sane options for 'ssh'.""" # Forked from Ray SSHOptions: # https://github.com/ray-project/ray/blob/master/python/ray/autoscaler/_private/command_runner.py @@ -79,7 +82,13 @@ def ssh_options_list(ssh_private_key: Optional[str], } # SSH Control will have a severe delay when using docker_ssh_proxy_command. # TODO(tian): Investigate why. - if ssh_control_name is not None and docker_ssh_proxy_command is None: + # We also do not use ControlMaster when we use `kubectl port-forward` + # to access Kubernetes pods over SSH+ProxyCommand. This is because the + # process running the ProxyCommand is kept alive for as long as the ssh + # session is running, and ControlMaster keeps the session alive, which + # results in a delay of 'ControlPersist' seconds for each ssh command run. + if ssh_control_name is not None and docker_ssh_proxy_command is None \ + and not disable_control_master: arg_dict.update({ # Control path: important optimization as we do multiple ssh in one # sky.launch(). @@ -136,6 +145,7 @@ def __init__( ssh_proxy_command: Optional[str] = None, port: int = 22, docker_user: Optional[str] = None, + disable_control_master: Optional[bool] = False, ): """Initialize SSHCommandRunner. @@ -158,13 +168,17 @@ def __init__( port: The port to use for ssh. docker_user: The docker user to use for ssh. If specified, the command will be run inside a docker container which have a ssh - server running at port sky.skylet.constants.DEFAULT_DOCKER_PORT. + server running at port sky.skylet.constants.DEFAULT_DOCKER_PORT. + disable_control_master: bool; specifies whether or not the ssh + command will utilize ControlMaster. We currently disable it + for k8s instances. 
""" self.ssh_private_key = ssh_private_key self.ssh_control_name = ( None if ssh_control_name is None else hashlib.md5( ssh_control_name.encode()).hexdigest()[:_HASH_MAX_LENGTH]) self._ssh_proxy_command = ssh_proxy_command + self.disable_control_master = disable_control_master if docker_user is not None: assert port is None or port == 22, ( f'port must be None or 22 for docker_user, got {port}.') @@ -190,6 +204,7 @@ def make_runner_list( ssh_private_key: str, ssh_control_name: Optional[str] = None, ssh_proxy_command: Optional[str] = None, + disable_control_master: Optional[bool] = False, port_list: Optional[List[int]] = None, docker_user: Optional[str] = None, ) -> List['SSHCommandRunner']: @@ -198,7 +213,8 @@ def make_runner_list( port_list = [22] * len(ip_list) return [ SSHCommandRunner(ip, ssh_user, ssh_private_key, ssh_control_name, - ssh_proxy_command, port, docker_user) + ssh_proxy_command, port, docker_user, + disable_control_master) for ip, port in zip(ip_list, port_list) ] @@ -228,7 +244,9 @@ def _ssh_base_command(self, *, ssh_mode: SshMode, ssh_proxy_command=self._ssh_proxy_command, docker_ssh_proxy_command=docker_ssh_proxy_command, port=self.port, - ) + [f'{self.ssh_user}@{self.ip}'] + disable_control_master=self.disable_control_master) + [ + f'{self.ssh_user}@{self.ip}' + ] def run( self, @@ -388,7 +406,7 @@ def rsync( ssh_proxy_command=self._ssh_proxy_command, docker_ssh_proxy_command=docker_ssh_proxy_command, port=self.port, - )) + disable_control_master=self.disable_control_master)) rsync_command.append(f'-e "ssh {ssh_options}"') # To support spaces in the path, we need to quote source and target. # rsync doesn't support '~' in a quoted local path, but it is ok to diff --git a/sky/utils/command_runner.pyi b/sky/utils/command_runner.pyi index e5feb5fb8db..425f5c60213 100644 --- a/sky/utils/command_runner.pyi +++ b/sky/utils/command_runner.pyi @@ -20,10 +20,13 @@ RSYNC_FILTER_OPTION: str RSYNC_EXCLUDE_OPTION: str -def ssh_options_list(ssh_private_key: Optional[str], - ssh_control_name: Optional[str], - *, - timeout: int = ...) -> List[str]: +def ssh_options_list( + ssh_private_key: Optional[str], + ssh_control_name: Optional[str], + *, + timeout: int = ..., + disable_control_master: Optional[bool] = False, +) -> List[str]: ... @@ -40,14 +43,18 @@ class SSHCommandRunner: ssh_control_name: Optional[str] docker_user: str port: int + disable_control_master: Optional[bool] - def __init__(self, - ip: str, - ssh_user: str, - ssh_private_key: str, - ssh_control_name: Optional[str] = ..., - port: int = ..., - docker_user: Optional[str] = ...) -> None: + def __init__( + self, + ip: str, + ssh_user: str, + ssh_private_key: str, + ssh_control_name: Optional[str] = ..., + port: int = ..., + docker_user: Optional[str] = ..., + disable_control_master: Optional[bool] = ..., + ) -> None: ... @staticmethod @@ -59,6 +66,7 @@ class SSHCommandRunner: ssh_proxy_command: Optional[str] = ..., port_list: Optional[List[int]] = ..., docker_user: Optional[str] = ..., + disable_control_master: Optional[bool] = ..., ) -> List['SSHCommandRunner']: ... 
diff --git a/sky/utils/kubernetes_utils.py b/sky/utils/kubernetes_utils.py index 78771d4cc5f..7b3f762705a 100644 --- a/sky/utils/kubernetes_utils.py +++ b/sky/utils/kubernetes_utils.py @@ -1,7 +1,9 @@ """Kubernetes utilities for SkyPilot.""" +import enum import math import os import re +import subprocess from typing import Any, Dict, List, Optional, Set, Tuple, Union from urllib.parse import urlparse @@ -12,11 +14,13 @@ from sky import exceptions from sky import sky_logging from sky.adaptors import kubernetes +from sky.backends import backend_utils from sky.utils import common_utils from sky.utils import env_options from sky.utils import ux_utils DEFAULT_NAMESPACE = 'default' +LOCAL_PORT_FOR_PORT_FORWARD = 23100 MEMORY_SIZE_UNITS = { 'B': 1, @@ -30,6 +34,20 @@ logger = sky_logging.init_logger(__name__) +class KubernetesNetworkingMode(enum.Enum): + """Enum for the different types of networking modes for accessing + jump pods. + """ + NODEPORT = 'NodePort' + PORT_FORWARD = 'Port_Forward' + + +class KubernetesServiceType(enum.Enum): + """Enum for the different types of services.""" + NODEPORT = 'NodePort' + CLUSTERIP = 'ClusterIP' + + class GPULabelFormatter: """Base class to define a GPU label formatter for a Kubernetes cluster @@ -355,7 +373,9 @@ def get_port(svc_name: str, namespace: str) -> int: return head_service.spec.ports[0].node_port -def get_external_ip(): +def get_external_ip(network_mode: Optional[KubernetesNetworkingMode]): + if network_mode == KubernetesNetworkingMode.PORT_FORWARD: + return '127.0.0.1' # Return the IP address of the first node with an external IP nodes = kubernetes.core_api().list_node().items for node in nodes: @@ -603,30 +623,97 @@ def __str__(self): return self.name -def get_ssh_proxy_command(private_key_path: str, sshjump_name: str, - namespace: str) -> str: +def construct_ssh_jump_command(private_key_path: str, + ssh_jump_port: int, + ssh_jump_ip: str, + proxy_cmd_path: Optional[str] = None) -> str: + ssh_jump_proxy_command = (f'ssh -tt -i {private_key_path} ' + '-o StrictHostKeyChecking=no ' + '-o UserKnownHostsFile=/dev/null ' + f'-o IdentitiesOnly=yes -p {ssh_jump_port} ' + f'-W %h:%p sky@{ssh_jump_ip}') + if proxy_cmd_path is not None: + proxy_cmd_path = os.path.expanduser(proxy_cmd_path) + # adding execution permission to the proxy command script + os.chmod(proxy_cmd_path, os.stat(proxy_cmd_path).st_mode | 0o111) + ssh_jump_proxy_command += f' -o ProxyCommand=\'{proxy_cmd_path}\' ' + return ssh_jump_proxy_command + + +def get_ssh_proxy_command(private_key_path: str, ssh_jump_name: str, + network_mode: KubernetesNetworkingMode, + namespace: str, port_fwd_proxy_cmd_path: str, + port_fwd_proxy_cmd_template: str) -> str: """Generates the SSH proxy command to connect through the SSH jump pod. + By default, establishing an SSH connection creates a communication + channel to a remote node by setting up a TCP connection. When a + ProxyCommand is specified, this default behavior is overridden. The command + specified in ProxyCommand is executed, and its standard input and output + become the communication channel for the SSH session. + + Pods within a Kubernetes cluster have internal IP addresses that are + typically not accessible from outside the cluster. Since the default TCP + connection of SSH won't allow access to these pods, we employ a + ProxyCommand to establish the required communication channel. We offer this + in two different networking options: NodePort/port-forward. + + With the NodePort networking mode, a NodePort service is launched. 
This + service opens an external port on the node which redirects to the desired + port within the pod. When establishing an SSH session in this mode, the + ProxyCommand makes use of this external port to create a communication + channel directly to port 22 of the jump pod, the default port the ssh + server listens on. + + With Port-forward mode, instead of directly exposing an external port, + 'kubectl port-forward' sets up a tunnel between a local port + (127.0.0.1:23100) and port 22 of the jump pod. Then we establish a TCP + connection to the local end of this tunnel, 127.0.0.1:23100, using 'socat'. + This is set up in the inner ProxyCommand of the nested ProxyCommand, and + the rest is the same as the NodePort approach, in which the outer + ProxyCommand establishes a communication channel between 127.0.0.1:23100 + and port 22 on the jump pod. Consequently, any stdin provided on the local + machine is forwarded through this tunnel to the application (SSH server) + listening in the pod. Similarly, any output from the application in the + pod is tunneled back and displayed in the terminal on the local machine. + Args: - private_key_path: Path to the private key to use for SSH. This key must - be authorized to access the SSH jump pod. - sshjump_name: Name of the SSH jump service to use + private_key_path: str; Path to the private key to use for SSH. + This key must be authorized to access the SSH jump pod. + ssh_jump_name: str; Name of the SSH jump service to use + network_mode: KubernetesNetworkingMode; networking mode for ssh + session. It is either 'NODEPORT' or 'PORT_FORWARD' namespace: Kubernetes namespace to use + port_fwd_proxy_cmd_path: str; path to the script used as ProxyCommand + with 'kubectl port-forward' + port_fwd_proxy_cmd_template: str; template used to create the + 'kubectl port-forward' ProxyCommand """ - # Fetch service port and IP to connect to for the jump svc - ssh_jump_port = get_port(sshjump_name, namespace) - ssh_jump_ip = get_external_ip() - - ssh_jump_proxy_command = (f'ssh -tt -i {private_key_path} ' - '-o StrictHostKeyChecking=no ' - '-o UserKnownHostsFile=/dev/null ' - '-o IdentitiesOnly=yes ' - f'-p {ssh_jump_port} -W %h:%p sky@{ssh_jump_ip}') - + # Fetch IP to connect to for the jump svc + ssh_jump_ip = get_external_ip(network_mode) + if network_mode == KubernetesNetworkingMode.NODEPORT: + ssh_jump_port = get_port(ssh_jump_name, namespace) + ssh_jump_proxy_command = construct_ssh_jump_command( + private_key_path, ssh_jump_port, ssh_jump_ip) + # Use kubectl port-forward with socat to establish the ssh session through + # a ClusterIP service, so that no ports need to be opened on the nodes. + else: + ssh_jump_port = LOCAL_PORT_FOR_PORT_FORWARD + vars_to_fill = { + 'ssh_jump_name': ssh_jump_name, + 'local_port': ssh_jump_port, + } + backend_utils.fill_template(port_fwd_proxy_cmd_template, + vars_to_fill, + output_path=port_fwd_proxy_cmd_path) + ssh_jump_proxy_command = construct_ssh_jump_command( + private_key_path, ssh_jump_port, ssh_jump_ip, + port_fwd_proxy_cmd_path) return ssh_jump_proxy_command -def setup_sshjump_svc(sshjump_name: str, namespace: str): +def setup_sshjump_svc(ssh_jump_name: str, namespace: str, + service_type: KubernetesServiceType): """Sets up Kubernetes service resource to access for SSH jump pod. 
This method acts as a necessary complement to be run along with @@ -635,23 +722,56 @@ Args: sshjump_name: Name to use for the SSH jump service namespace: Namespace to create the SSH jump service in + service_type: Whether to use a NodePort or a ClusterIP service to + expose the SSH jump pod for ssh access """ # Fill in template - ssh_key_secret and sshjump_image are not required for # the service spec, so we pass in empty strs. - content = fill_sshjump_template('', '', sshjump_name) + content = fill_sshjump_template('', '', ssh_jump_name, service_type.value) # Create service try: kubernetes.core_api().create_namespaced_service(namespace, content['service_spec']) except kubernetes.api_exception() as e: + # SSH Jump Pod service already exists. if e.status == 409: - logger.warning( - f'SSH Jump Service {sshjump_name} already exists in the ' - 'cluster, using it.') + ssh_jump_service = kubernetes.core_api().read_namespaced_service( + name=ssh_jump_name, namespace=namespace) + curr_svc_type = ssh_jump_service.spec.type + if service_type.value == curr_svc_type: + # The existing SSH Jump service's type matches the user's + # configured networking mode, so reuse it. + logger.warning( + f'SSH Jump Service {ssh_jump_name} already exists in the ' + 'cluster, using it.') + else: + # The existing service's type differs from the user's configured + # networking mode, so delete the existing service and recreate + # it following the user's configuration. + kubernetes.core_api().delete_namespaced_service( + name=ssh_jump_name, namespace=namespace) + kubernetes.core_api().create_namespaced_service( + namespace, content['service_spec']) + port_forward_mode = KubernetesNetworkingMode.PORT_FORWARD.value + nodeport_mode = KubernetesNetworkingMode.NODEPORT.value + clusterip_svc = KubernetesServiceType.CLUSTERIP.value + nodeport_svc = KubernetesServiceType.NODEPORT.value + curr_network_mode = port_forward_mode \ + if curr_svc_type == clusterip_svc else nodeport_mode + new_network_mode = nodeport_mode \ + if curr_svc_type == clusterip_svc else port_forward_mode + new_svc_type = nodeport_svc \ + if curr_svc_type == clusterip_svc else clusterip_svc + logger.info( + f'Switching the networking mode from ' + f'\'{curr_network_mode}\' to \'{new_network_mode}\' ' + f'following the networking configuration. Deleting the ' + f'existing \'{curr_svc_type}\' service and recreating it ' + f'as a \'{new_svc_type}\' service.') else: raise else: - logger.info(f'Created SSH Jump Service {sshjump_name}.') + logger.info(f'Created SSH Jump Service {ssh_jump_name}.') def setup_sshjump_pod(sshjump_name: str, sshjump_image: str, @@ -673,7 +793,10 @@ def setup_sshjump_pod(sshjump_name: str, sshjump_image: str, ssh_key_secret: Secret name for the SSH key stored in the cluster namespace: Namespace to create the SSH jump pod in """ - content = fill_sshjump_template(ssh_key_secret, sshjump_image, sshjump_name) + # Fill in the template - the service is created separately, so + # service_type is not required and we pass in an empty str. 
+ content = fill_sshjump_template(ssh_key_secret, sshjump_image, sshjump_name, + '') # ServiceAccount try: kubernetes.core_api().create_namespaced_service_account( @@ -784,7 +907,7 @@ def find(l, predicate): def fill_sshjump_template(ssh_key_secret: str, sshjump_image: str, - sshjump_name: str) -> Dict: + sshjump_name: str, service_type: str) -> Dict: template_path = os.path.join(sky.__root_dir__, 'templates', 'kubernetes-sshjump.yml.j2') if not os.path.exists(template_path): @@ -795,6 +918,25 @@ def fill_sshjump_template(ssh_key_secret: str, sshjump_image: str, j2_template = jinja2.Template(template) cont = j2_template.render(name=sshjump_name, image=sshjump_image, - secret=ssh_key_secret) + secret=ssh_key_secret, + service_type=service_type) content = yaml.safe_load(cont) return content + + +def check_port_forward_mode_dependencies() -> None: + """Checks if 'socat' and 'lsof' are installed.""" + for name, option in [('socat', '-V'), ('lsof', '-v')]: + try: + subprocess.run([name, option], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + check=True) + except FileNotFoundError: + with ux_utils.print_exception_no_traceback(): + raise RuntimeError( + f'`{name}` is required to set up the Kubernetes cloud with ' + f'`{KubernetesNetworkingMode.PORT_FORWARD.value}` default ' + 'networking mode, but it is not installed. ' + 'On Debian/Ubuntu systems, install it with:\n' + f' $ sudo apt install {name}') from None diff --git a/tests/test_config.py b/tests/test_config.py index dffd4843ffd..cb0ab42df78 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,9 +6,12 @@ from sky import skypilot_config from sky.utils import common_utils +from sky.utils import kubernetes_utils VPC_NAME = 'vpc-12345678' PROXY_COMMAND = 'ssh -W %h:%p -i ~/.ssh/id_rsa -o StrictHostKeyChecking=no' +NODEPORT_MODE_NAME = kubernetes_utils.KubernetesNetworkingMode.NODEPORT.value +PORT_FORWARD_MODE_NAME = kubernetes_utils.KubernetesNetworkingMode.PORT_FORWARD.value def _reload_config() -> None: @@ -34,6 +37,8 @@ def _create_config_file(config_file_path: pathlib.Path) -> None: vpc_name: {VPC_NAME} use_internal_ips: true ssh_proxy_command: {PROXY_COMMAND} + kubernetes: + networking: {NODEPORT_MODE_NAME} """)) @@ -67,14 +72,19 @@ def test_config_get_set_nested(monkeypatch, tmp_path) -> None: assert skypilot_config.get_nested(('aws', 'use_internal_ips'), None) assert skypilot_config.get_nested(('aws', 'ssh_proxy_command'), None) == PROXY_COMMAND - + assert skypilot_config.get_nested(('kubernetes', 'networking'), + None) == NODEPORT_MODE_NAME # Check set_nested() will copy the config dict and return a new dict new_config = skypilot_config.set_nested(('aws', 'ssh_proxy_command'), 'new_value') assert new_config['aws']['ssh_proxy_command'] == 'new_value' assert skypilot_config.get_nested(('aws', 'ssh_proxy_command'), None) == PROXY_COMMAND - + new_config = skypilot_config.set_nested(('kubernetes', 'networking'), + PORT_FORWARD_MODE_NAME) + assert new_config['kubernetes']['networking'] == PORT_FORWARD_MODE_NAME + assert skypilot_config.get_nested(('kubernetes', 'networking'), + None) == NODEPORT_MODE_NAME # Check that dumping the config to a file with the new None can be reloaded new_config2 = skypilot_config.set_nested(('aws', 'ssh_proxy_command'), None) new_config_path = tmp_path / 'new_config.yaml'
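For reviewers who want to try this end to end, here is a small usage sketch (not part of the diff) of how the new option is configured and read. The YAML keys mirror the test above, and the get_nested call mirrors setup_kubernetes_authentication(); the default is the port-forward mode.

# Hypothetical ~/.sky/config.yaml entry selecting the ssh networking mode:
#
#   kubernetes:
#     networking: Port_Forward   # or: NodePort
#
# Reading the value back, as setup_kubernetes_authentication() does:
from sky import skypilot_config
from sky.utils import kubernetes_utils

default_mode = kubernetes_utils.KubernetesNetworkingMode.PORT_FORWARD.value
mode = skypilot_config.get_nested(('kubernetes', 'networking'), default_mode)
print(f'Kubernetes ssh networking mode: {mode}')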