Skip to content

Commit

Permalink
fix(test): update test_script
Browse files Browse the repository at this point in the history
Signed-off-by: weiwee <wbwmat@gmail.com>
  • Loading branch information
sagewe committed Sep 30, 2022
1 parent 2c98c82 commit ef7210d
Showing 1 changed file with 44 additions and 14 deletions.
58 changes: 44 additions & 14 deletions python/test/test_script.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import contextlib
import sys
from subprocess import Popen

from fate.arch import Backend, Context
from fate.arch.computing.standalone import CSession
from fate.arch.context import Context, disable_inner_logs
from fate.arch.federation.standalone import StandaloneFederation


def host(federation_id, local_party, parties):
def host(federation_id, party, parties):
disable_inner_logs()
computing = CSession()
federation = StandaloneFederation(computing, federation_id, local_party, parties)
federation = StandaloneFederation(computing, federation_id, party, parties)
ctx = Context(
"guest", backend=Backend.STANDALONE, computing=computing, federation=federation
)
Expand All @@ -22,10 +26,10 @@ def host(federation_id, local_party, parties):
print(ctx.tensor.random_tensor((10, 10)))


def guest(federation_id, local_party, parties):
def guest(federation_id, party, parties):
disable_inner_logs()
computing = CSession()
federation = StandaloneFederation(computing, federation_id, local_party, parties)
federation = StandaloneFederation(computing, federation_id, party, parties)
ctx = Context("host", computing=computing, federation=federation)
with ctx.sub_ctx("predict") as sub_ctx:
sub_ctx.log.error("ctx inited")
Expand All @@ -38,16 +42,42 @@ def guest(federation_id, local_party, parties):

if __name__ == "__main__":
import argparse
import json
import tempfile

parser = argparse.ArgumentParser()
parser.add_argument("role")
parser.add_argument("--role", default=None)
parser.add_argument("--path", default=None)
args = parser.parse_args()
federation_id = "federation_id"
guest_party = ("guest", "guest_party_id")
host_party = ("host", "host_party_id")
parties = [guest_party, host_party]

if args.role == "guest":
guest(federation_id, guest_party, parties)
if args.role == "host":
host(federation_id, host_party, parties)
if not args.role:
federation_id = "federation_id"
guest_party = ("guest", "guest_party_id")
host_party = ("host", "host_party_id")
parties = [guest_party, host_party]
with contextlib.ExitStack() as stack:
f = stack.enter_context(tempfile.NamedTemporaryFile(mode="w"))
json.dump(
dict(party=guest_party, parties=parties, federation_id=federation_id), f
)
f.flush()
p1 = Popen([sys.executable, __file__, "--role", "guest", "--path", f.name])
f = stack.enter_context(tempfile.NamedTemporaryFile(mode="w"))
json.dump(
dict(party=host_party, parties=parties, federation_id=federation_id), f
)
f.flush()
p2 = Popen([sys.executable, __file__, "--role", "host", "--path", f.name])
p1.communicate()
p2.communicate()
elif args.role == "host":
with open(args.path) as f:
config = json.load(f)
config["party"] = tuple(config["party"])
config["parties"] = [tuple(p) for p in config["parties"]]
host(**config)
elif args.role == "guest":
with open(args.path) as f:
config = json.load(f)
config["party"] = tuple(config["party"])
config["parties"] = [tuple(p) for p in config["parties"]]
guest(**config)

0 comments on commit ef7210d

Please sign in to comment.