diff --git a/src/algos/swarm.py b/src/algos/swarm.py index 39ad0a0..863a8f3 100644 --- a/src/algos/swarm.py +++ b/src/algos/swarm.py @@ -34,10 +34,11 @@ def local_train(self): """ Train the model locally """ - avg_loss = self.model_utils.train(self.model, self.optim, + avg_loss, acc = self.model_utils.train(self.model, self.optim, self.dloader, self.loss_fn, self.device) - # print("Client {} finished training with loss {}".format(self.node_id, avg_loss)) + print("Client {} finished training with loss {}".format(self.node_id, avg_loss)) + return acc # self.log_utils.logger.log_tb(f"train_loss/client{client_num}", avg_loss, epoch) def local_test(self, **kwargs): @@ -118,7 +119,7 @@ def single_round(self,self_repr): def run_protocol(self): start_epochs = self.config.get("start_epochs", 0) total_epochs = self.config["epochs"] - test_accs = np.empty((self.num_clients, total_epochs)) # Transpose the shape + train_accs = np.empty((self.num_clients, total_epochs)) # Transpose the shape for round in range(start_epochs, total_epochs): #self.log_utils.logging.info("Client waiting for semaphore from {}".format(self.server_node)) @@ -126,7 +127,9 @@ def run_protocol(self): self.comm_utils.wait_for_signal(src=0, tag=self.tag.START) #print("semaphore received, start local training") # self.log_utils.logging.info("Client received semaphore from {}".format(self.server_node)) - self.local_train() + train_acc = self.local_train() + train_accs[self.node_id-1, round] = train_acc + np.save('./train_accs.npy', train_accs) #self.local_test() self_repr = self.get_representation() # self.log_utils.logging.info("Client {} sending done signal to {}".format(self.node_id, self.server_node)) @@ -144,8 +147,6 @@ def run_protocol(self): acc = self.local_test() print("Node {} test_acc:{:.4f}".format(self.node_id, acc)) self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH) - test_accs[self.node_id-1, round] = acc - np.save('./test_accs.npy', test_accs) class SWARMServer(BaseServer): def __init__(self, config) -> None: @@ -211,4 +212,4 @@ def run_protocol(self): for i, acc in enumerate(accs): train_accs[i, round] = acc - np.save('./train_accs.npy', train_accs) \ No newline at end of file + np.save('./test_accs.npy', train_accs) \ No newline at end of file