Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
[Rest Server] Generate random ports for scheduling (#3224)
Browse files Browse the repository at this point in the history
* Generate random ports for scheduling

Generate random ports for scheduling.

* Add container ports in task detail

Add container ports in task detail.
  • Loading branch information
abuccts authored Jul 23, 2019
1 parent 193b425 commit 84354bf
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 115 deletions.
181 changes: 86 additions & 95 deletions src/kube-runtime/src/parse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python
#!/usr/bin/python

# Copyright (c) Microsoft Corporation
# All rights reserved.
Expand All @@ -17,112 +17,103 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from __future__ import print_function

import os
import sys
import collections
import logging
import argparse

import json

log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)

def get_container_port(envs, name):
    """Look up an entry by name in a container env list.

    Args:
        envs: iterable of dicts, each read through its "name" and "value" keys.
        name: the "name" to search for.

    Returns:
        The "value" of the first entry whose "name" equals `name`,
        or None when no entry matches.
    """
    matching_values = (entry["value"] for entry in envs if entry["name"] == name)
    return next(matching_values, None)

def export(k, v):
    """Print a shell ``export`` statement that assigns `v` to `k`.

    The value is written between single quotes; a single quote inside `v`
    is not escaped.

    Args:
        k: variable name.
        v: variable value, formatted with ``%s``.
    """
    # print() as a call works both with and without
    # `from __future__ import print_function`; the print statement does not.
    print("export %s='%s'" % (k, v))


def gen_runtime_env(framework):
    """Export runtime environment variables derived from a framework object.

    Emitted through `export`:
        PAI_CURRENT_TASK_ROLE_CURRENT_TASK_INDEX  (only if FC_TASK_INDEX is set)
        PAI_TASK_ROLE_INDEX                       (legacy)
        PAI_HOST_IP_$taskRole_$taskIndex
        PAI_TASK_ROLE_$name_HOST_LIST             (legacy, subject to removal)

    PAI_$taskRole_$currentTaskIndex_$type_PORT is also a legacy name but is
    not generated by this function.

    Reads the FC_TASK_INDEX and FC_TASKROLE_NAME environment variables.

    Args:
        framework: dict providing ``spec.taskRoles`` and
            ``status.attemptStatus.taskRoleStatuses``.
    """
    index_id = os.environ.get("FC_TASK_INDEX")
    if index_id is None:
        log.error("expect FC_TASK_INDEX set as environment variable")
    else:
        export("PAI_CURRENT_TASK_ROLE_CURRENT_TASK_INDEX", index_id)

    cur_task_role_name = os.environ.get("FC_TASKROLE_NAME")

    # key is role_name, value is its PAI_CURRENT_CONTAINER_PORT val
    role_cur_port_map = {}
    # key is role_name, value is its PAI_CONTAINER_HOST_PORT_LIST val
    role_ports_map = {}

    for idx, role in enumerate(framework["spec"]["taskRoles"]):
        role_name = role["name"]
        if role_name == cur_task_role_name:
            export("PAI_TASK_ROLE_INDEX", idx)  # TODO legacy environment

        container_envs = role["task"]["pod"]["spec"]["containers"][0]["env"]
        role_cur_port_map[role_name] = get_container_port(
            container_envs, "PAI_CURRENT_CONTAINER_PORT")
        role_ports_map[role_name] = get_container_port(
            container_envs, "PAI_CONTAINER_HOST_PORT_LIST")

    log.info("role_cur_port_map is %s, role_ports_map is %s",
             role_cur_port_map, role_ports_map)

    # key is role name, value is a map with key of index, value of ip
    role_status_map = collections.defaultdict(dict)
    for role_status in framework["status"]["attemptStatus"]["taskRoleStatuses"]:
        name = role_status["name"]
        for status in role_status["taskStatuses"]:
            role_status_map[name][status["index"]] = status["attemptStatus"]["podIP"]

    log.info("role_status_map is %s", role_status_map)

    # key is role name, value is "ip:port,ip:port,..." ordered by task index
    role_host_port_map = {}
    for role_name, idx_map in role_status_map.items():
        port = role_cur_port_map[role_name]
        # Walk the indices that are actually present instead of assuming
        # they are exactly 0..len-1.
        role_host_port_map[role_name] = ",".join(
            idx_map[i] + ":" + str(port) for i in sorted(idx_map))

    for role_name, idx_map in role_status_map.items():
        for idx, ip in idx_map.items():
            export("PAI_HOST_IP_%s_%d" % (role_name, idx), ip)

    # following is legacy, subject to removal
    for role_name, host_port_list in role_host_port_map.items():
        export("PAI_TASK_ROLE_%s_HOST_LIST" % role_name, host_port_list)

if __name__ == '__main__':
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s",
level=logging.INFO)
print("export {}={}".format(k, v))


def generate_runtime_env(framework):
    """Generate runtime env variables for tasks.

    Exported for every task of every task role:
        PAI_HOST_IP_$taskRole_$taskIndex
        PAI_PORT_LIST_$taskRole_$taskIndex_$portType
        PAI_$taskRole_$taskIndex_$portType_PORT      (backward compatibility)

    Exported once per task role:
        PAI_TASK_ROLE_$taskRole_HOST_LIST            (backward compatibility)

    Exported only for the task matching FC_TASKROLE_NAME / FC_TASK_INDEX
    (backward compatibility):
        PAI_CURRENT_CONTAINER_IP
        PAI_CURRENT_CONTAINER_PORT
        PAI_CONTAINER_HOST_IP
        PAI_CONTAINER_HOST_PORT
        PAI_CONTAINER_SSH_PORT
        PAI_CONTAINER_HOST_$portType_PORT_LIST
        PAI_CONTAINER_HOST_PORT_LIST

    Args:
        framework: Framework object generated by frameworkbarrier.

    Raises:
        KeyError: if a task role has no "http" port entry, or the matching
            task has no "ssh" port entry, in its port scheduling spec.
    """
    current_task_index = os.environ.get("FC_TASK_INDEX")
    current_taskrole_name = os.environ.get("FC_TASKROLE_NAME")

    taskroles = {}
    for taskrole in framework["spec"]["taskRoles"]:
        annotations = taskrole["task"]["pod"]["metadata"]["annotations"]
        taskroles[taskrole["name"]] = {
            "number": taskrole["taskNumber"],
            "ports": json.loads(annotations["rest-server/port-scheduling-spec"]),
        }
    logger.info("task roles: %s", taskroles)

    for taskrole in framework["status"]["attemptStatus"]["taskRoleStatuses"]:
        name = taskrole["name"]
        ports = taskroles[name]["ports"]

        host_list = []
        for task in taskrole["taskStatuses"]:
            index = task["index"]
            current_ip = task["attemptStatus"]["podHostIP"]

            # export ip/port for task role
            export("PAI_HOST_IP_{}_{}".format(name, index), current_ip)
            port_strs = {}
            for port in ports:
                port_strs[port] = _port_list_str(ports[port], index)
                export("PAI_PORT_LIST_{}_{}_{}".format(name, index, port), port_strs[port])
                export("PAI_{}_{}_{}_PORT".format(name, index, port), port_strs[port])

            # export ip/port for current container
            # FC_TASK_INDEX is read from the environment as a string, so the
            # task index is converted before comparing.
            if current_taskrole_name == name and current_task_index == str(index):
                export("PAI_CURRENT_CONTAINER_IP", current_ip)
                export("PAI_CURRENT_CONTAINER_PORT", _port_base(ports["http"], index))
                export("PAI_CONTAINER_HOST_IP", current_ip)
                export("PAI_CONTAINER_HOST_PORT", _port_base(ports["http"], index))
                export("PAI_CONTAINER_SSH_PORT", _port_base(ports["ssh"], index))
                for port in ports:
                    export("PAI_CONTAINER_HOST_{}_PORT_LIST".format(port), port_strs[port])
                # "<type>:<p1>,<p2>;" per port type, each entry ending in ';'
                port_str = "".join(
                    "{}:{};".format(port, port_strs[port]) for port in ports)
                export("PAI_CONTAINER_HOST_PORT_LIST", port_str)

            host_list.append("{}:{}".format(current_ip, _port_base(ports["http"], index)))
        export("PAI_TASK_ROLE_{}_HOST_LIST".format(name), ",".join(host_list))


def _port_base(port_spec, task_index):
    """Return start + count * task_index for one port scheduling entry."""
    return int(port_spec["start"]) + int(port_spec["count"]) * int(task_index)


def _port_list_str(port_spec, task_index):
    """Return the comma-separated ports of one entry for task `task_index`."""
    base = _port_base(port_spec, task_index)
    return ",".join(str(p) for p in range(base, base + int(port_spec["count"])))


# Script entry point: load the framework.json given on the command line and
# print the runtime `export` statements for it.
if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )
    parser = argparse.ArgumentParser()
    # Register the positional argument once; registering it twice makes
    # argparse require two command-line values.
    parser.add_argument("framework_json", help="framework.json path generated by frameworkbarrier")
    args = parser.parse_args()

    logger.info("loading json from %s", args.framework_json)
    with open(args.framework_json) as f:
        framework = json.load(f)
    generate_runtime_env(framework)
49 changes: 29 additions & 20 deletions src/rest-server/src/models/v2/job/k8s.js
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,21 @@ const convertFrameworkSummary = (framework) => {
};
};

const convertTaskDetail = (taskStatus) => {
const convertTaskDetail = (taskStatus, ports) => {
const containerPorts = {};
if (ports) {
const randomPorts = JSON.parse(ports);
for (let port of Object.keys(randomPorts)) {
containerPorts[port] = randomPorts[port].start + taskStatus.index * randomPorts[port].count;
}
}
const completionStatus = taskStatus.attemptStatus.completionStatus;
return {
taskIndex: taskStatus.index,
taskState: convertState(taskStatus.state, completionStatus ? completionStatus.code : null),
containerId: taskStatus.attemptStatus.podName,
containerIp: taskStatus.attemptStatus.podHostIP,
containerPorts: {}, // TODO
containerPorts,
containerGpus: 0, // TODO
containerLog: '',
containerExitCode: completionStatus ? completionStatus.code : null,
Expand Down Expand Up @@ -154,18 +161,36 @@ const convertFrameworkDetail = (framework) => {
},
taskRoles: {},
};
const ports = {};
for (let taskRoleSpec of framework.spec.taskRoles) {
ports[taskRoleSpec.name] = taskRoleSpec.task.pod.metadata.annotations['rest-server/port-scheduling-spec'];
}
for (let taskRoleStatus of framework.status.attemptStatus.taskRoleStatuses) {
detail.taskRoles[taskRoleStatus.name] = {
taskRoleStatus: {
name: taskRoleStatus.name,
},
taskStatuses: taskRoleStatus.taskStatuses.map(convertTaskDetail),
taskStatuses: taskRoleStatus.taskStatuses.map((status) => convertTaskDetail(status, ports[taskRoleStatus.name])),
};
}
return detail;
};

const generateTaskRole = (taskRole, labels, config) => {
const ports = config.taskRoles[taskRole].resourcePerInstance.ports || {};
for (let port of ['ssh', 'http']) {
if (!(port in ports)) {
ports[port] = 1;
}
}
// schedule ports in [20000, 40000) randomly
const randomPorts = {};
for (let port of Object.keys(ports)) {
randomPorts[port] = {
start: Math.floor((Math.random() * 20000) + 20000),
count: ports[port],
};
}
const frameworkTaskRole = {
name: convertName(taskRole),
taskNumber: config.taskRoles[taskRole].instances || 1,
Expand All @@ -182,6 +207,7 @@ const generateTaskRole = (taskRole, labels, config) => {
},
annotations: {
'container.apparmor.security.beta.kubernetes.io/main': 'unconfined',
'rest-server/port-scheduling-spec': JSON.stringify(randomPorts),
'hivedscheduler.microsoft.com/pod-scheduling-spec': yaml.safeDump(config.taskRoles[taskRole].hivedPodSpec),
},
},
Expand Down Expand Up @@ -239,14 +265,6 @@ const generateTaskRole = (taskRole, labels, config) => {
},
},
},
{
name: 'GPU_ID',
valueFrom: {
fieldRef: {
fieldPath: `metadata.annotations['hivedscheduler.microsoft.com/pod-gpu-isolation']`,
},
},
},
],
securityContext: {
capabilities: {
Expand Down Expand Up @@ -377,15 +395,6 @@ const generateFrameworkDescription = (frameworkName, virtualCluster, config, raw
},
},
},
// use random ports temporarily
{
name: 'PAI_CURRENT_CONTAINER_PORT',
value: `${Math.floor((Math.random() * 10000) + 10000)}`,
},
{
name: 'PAI_CONTAINER_SSH_PORT',
value: `${Math.floor((Math.random() * 10000) + 10000)}`,
},
]));
frameworkDescription.spec.taskRoles.push(taskRoleDescription);
}
Expand Down

0 comments on commit 84354bf

Please sign in to comment.