Skip to content

Commit

Permalink
Support multiple host names for FLARE server (#3018)
Browse files Browse the repository at this point in the history
* support multiple host names for fl server

* add connect_to check

* fix server side overseer agent

* add server identity to fed_client.json

* fix format

---------

Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
Co-authored-by: Isaac Yang <isaacy@nvidia.com>
  • Loading branch information
3 people authored Oct 11, 2024
1 parent b704c98 commit 552bb36
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 216 deletions.
1 change: 1 addition & 0 deletions nvflare/apis/utils/format_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

type_pattern_mapping = {
"server": r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$",
"host_name": r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$",
"overseer": r"^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$",
"sp_end_point": r"^((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9]):[0-9]*:[0-9]*)$",
"client": r"^[A-Za-z0-9-_]+$",
Expand Down
48 changes: 42 additions & 6 deletions nvflare/lighter/impl/cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,13 @@ def _build_write_cert_pair(self, participant, base_name, ctx):
f.write(serialize_cert(cert))
with open(os.path.join(dest_dir, f"{base_name}.key"), "wb") as f:
f.write(serialize_pri_key(pri_key))
if base_name == "client" and (listening_host := participant.props.get("listening_host")):
tmp_participant = Participant("server", listening_host, participant.org)
if base_name == "client" and (listening_host := participant.get_listening_host()):
tmp_participant = Participant(
type="server",
name=participant.name,
org=participant.org,
default_host=listening_host,
)
tmp_pri_key, tmp_cert = self.get_pri_key_cert(tmp_participant)
with open(os.path.join(dest_dir, "server.crt"), "wb") as f:
f.write(serialize_cert(tmp_cert))
Expand Down Expand Up @@ -142,10 +147,20 @@ def get_pri_key_cert(self, participant):
subject = self.get_subject(participant)
subject_org = participant.org
if participant.type == "admin":
role = participant.props.get("role")
role = participant.get_prop("role")
else:
role = None
cert = self._generate_cert(subject, subject_org, self.issuer, self.pri_key, pub_key, role=role)

server = participant if participant.type == "server" else None
cert = self._generate_cert(
subject,
subject_org,
self.issuer,
self.pri_key,
pub_key,
role=role,
server=server,
)
return pri_key, cert

def get_subject(self, participant):
Expand All @@ -157,10 +172,20 @@ def _generate_keys(self):
return pri_key, pub_key

def _generate_cert(
self, subject, subject_org, issuer, signing_pri_key, subject_pub_key, valid_days=360, ca=False, role=None
self,
subject,
subject_org,
issuer,
signing_pri_key,
subject_pub_key,
valid_days=360,
ca=False,
role=None,
server: Participant = None,
):
x509_subject = self._x509_name(subject, subject_org, role)
x509_issuer = self._x509_name(issuer)

builder = (
x509.CertificateBuilder()
.subject_name(x509_subject)
Expand All @@ -174,7 +199,6 @@ def _generate_cert(
+ datetime.timedelta(days=valid_days)
# Sign our certificate with our private key
)
.add_extension(x509.SubjectAlternativeName([x509.DNSName(subject)]), critical=False)
)
if ca:
builder = (
Expand All @@ -188,6 +212,18 @@ def _generate_cert(
)
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=False)
)

if server:
# This is to generate a server cert.
# Use SubjectAlternativeName for all host names
default_host = server.get_default_host()
host_names = server.get_host_names()
sans = [x509.DNSName(default_host)]
if host_names:
for h in host_names:
if h != default_host:
sans.append(x509.DNSName(h))
builder = builder.add_extension(x509.SubjectAlternativeName(sans), critical=False)
return builder.sign(signing_pri_key, hashes.SHA256(), default_backend())

def _x509_name(self, cn_name, org_name=None, role=None):
Expand Down
22 changes: 0 additions & 22 deletions nvflare/lighter/impl/local_cert.py

This file was deleted.

69 changes: 0 additions & 69 deletions nvflare/lighter/impl/local_static_file.py

This file was deleted.

132 changes: 84 additions & 48 deletions nvflare/lighter/impl/static_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import yaml

from nvflare.lighter import utils
from nvflare.lighter.spec import Builder
from nvflare.lighter.spec import Builder, Participant


class StaticFileBuilder(Builder):
Expand Down Expand Up @@ -124,28 +124,18 @@ def _build_server(self, server, ctx):
dest_dir = self.get_kit_dir(server, ctx)
server_0 = config["servers"][0]
server_0["name"] = self.project_name
admin_port = server.props.get("admin_port", 8003)
admin_port = server.get_prop("admin_port", 8003)
ctx["admin_port"] = admin_port
fed_learn_port = server.props.get("fed_learn_port", 8002)
fed_learn_port = server.get_prop("fed_learn_port", 8002)
ctx["fed_learn_port"] = fed_learn_port
ctx["server_name"] = self.get_server_name(server)
server_0["service"]["target"] = f"{self.get_server_name(server)}:{fed_learn_port}"
server_0["service"]["scheme"] = self.scheme
server_0["admin_host"] = self.get_server_name(server)
server_0["admin_port"] = admin_port
if self.overseer_agent:
overseer_agent = copy.deepcopy(self.overseer_agent)
if overseer_agent.get("overseer_exists", True):
overseer_agent["args"] = {
"role": "server",
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": self.get_server_name(server),
"fl_port": str(fed_learn_port),
"admin_port": str(admin_port),
}
overseer_agent.pop("overseer_exists", None)
config["overseer_agent"] = overseer_agent

self._prepare_overseer_agent(server, config, "server", ctx)

utils._write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t")
replacement_dict = {
"admin_port": admin_port,
Expand Down Expand Up @@ -212,14 +202,15 @@ def _build_server(self, server, ctx):
)

def _build_client(self, client, ctx):
project = ctx["project"]
server = project.get_server()
if not server:
raise ValueError("missing server definition in project")
config = json.loads(self.template["fed_client"])
dest_dir = self.get_kit_dir(client, ctx)
fed_learn_port = ctx.get("fed_learn_port")
server_name = ctx.get("server_name")
# config["servers"][0]["service"]["target"] = f"{server_name}:{fed_learn_port}"
config["servers"][0]["service"]["scheme"] = self.scheme
config["servers"][0]["name"] = self.project_name
# config["enable_byoc"] = client.enable_byoc
config["servers"][0]["identity"] = server.name # the official identity of the server
replacement_dict = {
"client_name": f"{client.subject}",
"config_folder": self.config_folder,
Expand All @@ -228,23 +219,8 @@ def _build_client(self, client, ctx):
"type": "client",
"cln_uid": f"uid={client.subject}",
}
if self.overseer_agent:
overseer_agent = copy.deepcopy(self.overseer_agent)
if overseer_agent.get("overseer_exists", True):
overseer_agent["args"] = {
"role": "client",
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": client.subject,
}
overseer_agent.pop("overseer_exists", None)
config["overseer_agent"] = overseer_agent
# components = client.props.get("components", [])
# config["components"] = list()
# for comp in components:
# temp_dict = {"id": comp}
# temp_dict.update(components[comp])
# config["components"].append(temp_dict)

self._prepare_overseer_agent(client, config, "client", ctx)

utils._write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t")
if self.docker_image:
Expand Down Expand Up @@ -302,6 +278,76 @@ def _build_client(self, client, ctx):
"t",
)

def _check_host_name(self, host_name: str, server: Participant) -> str:
if host_name == server.get_default_host():
# Use the default host - OK
return ""

available_host_names = server.get_host_names()
if available_host_names and host_name in available_host_names:
# use alternative host name - OK
return ""

return f"unknown host name '{host_name}'"

def _prepare_overseer_agent(self, participant, config, role, ctx):
project = ctx["project"]
server = project.get_server()
if not server:
raise ValueError(f"Missing server definition in project {project.name}")

fl_port = server.get_prop("fed_learn_port", 8002)
admin_port = server.get_prop("admin_port", 8003)

if self.overseer_agent:
overseer_agent = copy.deepcopy(self.overseer_agent)
if overseer_agent.get("overseer_exists", True):
if role == "server":
overseer_agent["args"] = {
"role": role,
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": server.name,
"fl_port": str(fl_port),
"admin_port": str(admin_port),
}
else:
overseer_agent["args"] = {
"role": role,
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": participant.subject,
}
else:
# do not use overseer system
# Dummy overseer agent is used here
if role == "server":
# the server expects the "connect_to" to be the same as its name
# otherwise the host name generated by the dummy agent won't be accepted!
connect_to = server.name
else:
connect_to = participant.get_connect_to()
if connect_to:
err = self._check_host_name(connect_to, server)
if err:
raise ValueError(f"bad connect_to in {participant.subject}: {err}")
else:
# connect_to is not explicitly specified: use the server's name by default
# Note: by doing this dynamically, we guarantee the sp_end_point to be correct, even if the
# project.yaml does not specify the default server host correctly!
connect_to = server.get_default_host()

# change the sp_end_point to use connect_to
agent_args = overseer_agent.get("args")
if agent_args:
sp_end_point = agent_args.get("sp_end_point")
if sp_end_point:
# format of the sp_end_point: server_host_name:fl_port:admin_port
agent_args["sp_end_point"] = f"{connect_to}:{fl_port}:{admin_port}"

overseer_agent.pop("overseer_exists", None)
config["overseer_agent"] = overseer_agent

def _build_admin(self, admin, ctx):
dest_dir = self.get_kit_dir(admin, ctx)
admin_port = ctx.get("admin_port")
Expand Down Expand Up @@ -338,17 +384,7 @@ def _build_admin(self, admin, ctx):
def prepare_admin_config(self, admin, ctx):
config = json.loads(self.template["fed_admin"])
agent_config = dict()
if self.overseer_agent:
overseer_agent = copy.deepcopy(self.overseer_agent)
if overseer_agent.get("overseer_exists", True):
overseer_agent["args"] = {
"role": "admin",
"overseer_end_point": ctx.get("overseer_end_point", ""),
"project": self.project_name,
"name": admin.subject,
}
overseer_agent.pop("overseer_exists", None)
agent_config["overseer_agent"] = overseer_agent
self._prepare_overseer_agent(admin, agent_config, "admin", ctx)
config["admin"].update(agent_config)
return config

Expand Down
Loading

0 comments on commit 552bb36

Please sign in to comment.