Skip to content

Commit

Permalink
add auc calculation, swarm latest fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tremblerz committed Jun 13, 2024
1 parent 9190bbd commit 02ab4ff
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 33 deletions.
43 changes: 24 additions & 19 deletions src/algos/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def local_train(self):
avg_loss, acc = self.model_utils.train(self.model, self.optim,
self.dloader, self.loss_fn,
self.device)
print("Client {} finished training with loss {}".format(self.node_id, avg_loss))
# print("Client {} finished training with loss {}".format(self.node_id, avg_loss))
return acc
# self.log_utils.logger.log_tb(f"train_loss/client{client_num}", avg_loss, epoch)

Expand Down Expand Up @@ -119,25 +119,21 @@ def single_round(self,self_repr):
def run_protocol(self):
start_epochs = self.config.get("start_epochs", 0)
total_epochs = self.config["epochs"]
train_accs = np.empty((self.num_clients, total_epochs)) # Transpose the shape
num_clients = self.config["num_clients"]
train_accs = np.zeros((num_clients, total_epochs))

for round in range(start_epochs, total_epochs):
#self.log_utils.logging.info("Client waiting for semaphore from {}".format(self.server_node))
#print("Client waiting for semaphore from {}".format(self.server_node))
self.comm_utils.wait_for_signal(src=0, tag=self.tag.START)
#print("semaphore received, start local training")
# self.log_utils.logging.info("Client received semaphore from {}".format(self.server_node))
train_acc = self.local_train()
train_accs[self.node_id-1, round] = train_acc
print("Node {} train_acc:{:.4f}".format(self.node_id, train_acc))
np.save('./train_accs.npy', train_accs)
#self.local_test()
print(train_accs)
self.comm_utils.send_signal(dest=0, data=train_acc, tag=self.tag.FINISH)

self_repr = self.get_representation()
# self.log_utils.logging.info("Client {} sending done signal to {}".format(self.node_id, self.server_node))
#print("sending signal to node {}".format(self.server_node))
if self.node_id == 1:
repr = self.single_round(self_repr)
# self.log_utils.logging.info("Client {} waiting to get new model from {}".format(self.node_id, self.server_node))
#get all representation from server
else:
self.comm_utils.send_signal(dest=self.server_node, data=self_repr, tag=self.tag.DONE)
print("Node {} waiting signal from node 1".format(self.node_id))
Expand All @@ -148,6 +144,9 @@ def run_protocol(self):
print("Node {} test_acc:{:.4f}".format(self.node_id, acc))
self.comm_utils.send_signal(dest=0, data=acc, tag=self.tag.FINISH)

# Save the train_accs array to a text file
np.savetxt('train_accs.txt', train_accs, delimiter=',')

class SWARMServer(BaseServer):
def __init__(self, config) -> None:
super().__init__(config)
Expand All @@ -166,7 +165,7 @@ def send_representations(self, representations):
self.comm_utils.send_signal(client_node,
representations,
self.tag.UPDATES)
self.log_utils.log_console("Server sent {} representations to node {}".format(len(representations),client_node))
# self.log_utils.log_console("Server sent {} representations to node {}".format(len(representations),client_node))
#self.model.load_state_dict(representation)

def test(self) -> float:
Expand All @@ -188,8 +187,8 @@ def single_round(self):
Runs the whole training procedure
"""
for client_node in self.clients:
self.log_utils.log_console("Server sending semaphore from {} to {}".format(self.node_id,
client_node))
# self.log_utils.log_console("Server sending semaphore from {} to {}".format(self.node_id,
# client_node))
self.comm_utils.send_signal(dest=client_node, data=None, tag=self.tag.START)


Expand All @@ -199,17 +198,23 @@ def run_protocol(self):
self.log_utils.log_console("Starting iid clients federated averaging")
start_epochs = self.config.get("start_epochs", 0)
total_epochs = self.config["epochs"]
test_accs = np.zeros((12, 210))
train_accs = np.zeros((12, 210))

for round in range(start_epochs, total_epochs):
self.round = round
self.log_utils.log_console("Starting round {}".format(round))
self.single_round()
train_acc = self.comm_utils.wait_for_all_clients(self.clients, self.tag.FINISH)
test_acc = self.comm_utils.wait_for_all_clients(self.clients, self.tag.FINISH)
self.log_utils.log_console("Round {} done; acc {}".format(round, train_acc))
self.log_utils.log_console("Round {} done; acc {}".format(round, test_acc))

accs = self.comm_utils.wait_for_all_clients(self.clients, self.tag.FINISH)
self.log_utils.log_console("Round {} done; acc {}".format(round,accs))

for i, acc in enumerate(accs):
for i, acc in enumerate(train_acc):
train_accs[i, round] = acc

np.save('./test_accs.npy', train_accs)
for i, acc in enumerate(test_acc):
test_accs[i, round] = acc

np.save('./train_accs.npy', train_accs)
np.save('./test_accs.npy', test_accs)
6 changes: 3 additions & 3 deletions src/configs/non_iid_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def get_fmow_support(num_clients, domains=FMOW_DMN):
"model": "resnet10",
"model_lr": 1e-4,
"batch_size": 16,
"num_teachers": 5,
"num_teachers": 1,

# params for model
"position": 0,
Expand Down Expand Up @@ -617,7 +617,7 @@ def assign_colab(clients):

# Clients selection
"num_clients": swarm_client,
"target_clients": 2,
"target_clients": 3,
"similarity": "CosineSimilarity", #"EuclideanDistance", "CosineSimilarity",
#"community_type": "dataset",
"with_sim_consensus": True,
Expand Down Expand Up @@ -854,7 +854,7 @@ def assign_colab(clients):

# current_config = fedcentral

current_config = swarm
current_config = defkt
# current_config["test_param"] ="community_type"
# current_config["test_values"] = ["dataset", None]

Expand Down
130 changes: 130 additions & 0 deletions src/utils/auc.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"id": "52f2e02a-69de-40fa-8b08-5b7f26f5066a",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "117d9a6b-d855-4d74-b591-0cb4c6e57920",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Shape of train_acc: (12, 210)\n",
"Shape of test_acc: (12, 210)\n",
"Individual AUCs for each client:\n",
"Client 1 - Training AUC: 190.15625, Testing AUC: 55.803525641025644\n",
"Client 2 - Training AUC: 186.25, Testing AUC: 55.803525641025644\n",
"Client 3 - Training AUC: 184.515625, Testing AUC: 55.803525641025644\n",
"Client 4 - Training AUC: 184.03125, Testing AUC: 55.803525641025644\n",
"Client 5 - Training AUC: 180.234375, Testing AUC: 40.629852744310575\n",
"Client 6 - Training AUC: 181.984375, Testing AUC: 40.629852744310575\n",
"Client 7 - Training AUC: 181.09375, Testing AUC: 40.629852744310575\n",
"Client 8 - Training AUC: 177.28125, Testing AUC: 40.629852744310575\n",
"Client 9 - Training AUC: 186.03125, Testing AUC: 60.72624798711756\n",
"Client 10 - Training AUC: 178.59375, Testing AUC: 60.72624798711756\n",
"Client 11 - Training AUC: 187.078125, Testing AUC: 60.72624798711756\n",
"Client 12 - Training AUC: 181.203125, Testing AUC: 60.72624798711756\n",
"\n",
"Summary Statistics for Training AUC:\n",
"Mean: 183.20442708333334, Median: 183.0078125, Standard Deviation: 3.6299101215838543\n",
"\n",
"Summary Statistics for Testing AUC:\n",
"Mean: 52.386542124151255, Median: 55.803525641025644, Standard Deviation: 8.552703576636738\n",
"\n",
"Training Accuracies for each client and round:\n",
"[[0.1875 0.125 0.25 ... 1. 1. 1. ]\n",
" [0.09375 0.1875 0.28125 ... 1. 1. 0.96875]\n",
" [0.15625 0.15625 0.25 ... 1. 1. 1. ]\n",
" ...\n",
" [0.1875 0.15625 0.28125 ... 1. 1. 1. ]\n",
" [0.15625 0.21875 0.34375 ... 1. 1. 1. ]\n",
" [0.1875 0.125 0.3125 ... 0.9375 1. 0.96875]]\n"
]
}
],
"source": [
"train_acc = np.load('./train_accs.npy')\n",
"test_acc = np.load('./test_accs.npy')\n",
"\n",
"print(\"Shape of train_acc:\", train_acc.shape)\n",
"print(\"Shape of test_acc:\", test_acc.shape)\n",
"\n",
"# Ensure the arrays have the expected shape (12 clients, 210 epochs)\n",
"if train_acc.shape[1] != 210 or test_acc.shape[1] != 210:\n",
" raise ValueError(\"The accuracy arrays must have 210 epochs. Ensure shape is (num_clients, 210).\")\n",
"\n",
"train_auc = [np.trapz(train_acc[i, :]) for i in range(train_acc.shape[0])]\n",
"test_auc = [np.trapz(test_acc[i, :]) for i in range(test_acc.shape[0])]\n",
"\n",
"print(\"Individual AUCs for each client:\")\n",
"for i, (t_auc, te_auc) in enumerate(zip(train_auc, test_auc)):\n",
" print(f\"Client {i+1} - Training AUC: {t_auc}, Testing AUC: {te_auc}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b8625b5-ab58-4f90-b099-0c9f7a83356c",
"metadata": {},
"outputs": [],
"source": [
"train_auc_mean = np.mean(train_auc)\n",
"train_auc_median = np.median(train_auc)\n",
"train_auc_std = np.std(train_auc)\n",
"\n",
"test_auc_mean = np.mean(test_auc)\n",
"test_auc_median = np.median(test_auc)\n",
"test_auc_std = np.std(test_auc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd040a6b-7804-4b13-b516-52c3ff1e486c",
"metadata": {},
"outputs": [],
"source": [
"print(\"\\nSummary Statistics for Training AUC:\")\n",
"print(f\"Mean: {train_auc_mean}, Median: {train_auc_median}, Standard Deviation: {train_auc_std}\")\n",
"\n",
"print(\"\\nSummary Statistics for Testing AUC:\")\n",
"print(f\"Mean: {test_auc_mean}, Median: {test_auc_median}, Standard Deviation: {test_auc_std}\")\n",
"\n",
"print(\"\\nTraining Accuracies for each client and round:\")\n",
"print(train_acc)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
28 changes: 17 additions & 11 deletions src/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,36 +49,42 @@ def get_model(self, model_name:str, dset:str, device:torch.device, device_ids:li
model = model.to(device)
return model

def train(self, model:nn.Module, optim, dloader, loss_fn, device: torch.device, test_loader=None, **kwargs) -> Tuple[float, float]:
"""TODO: generate docstring
"""
def train(self, model: nn.Module, optim, dloader, loss_fn, device: torch.device, test_loader=None, **kwargs) -> Tuple[float, float]:
model.train()
train_loss = 0
correct = 0
total_samples = 0

for batch_idx, (data, target) in enumerate(dloader):
data, target = data.to(device), target.to(device)
optim.zero_grad()
position = kwargs.get("position", 0)
output = model(data, position=position)
if kwargs.get("apply_softmax", False):
output = nn.functional.log_softmax(output, dim=1) # type: ignore
output = nn.functional.log_softmax(output, dim=1)
loss = loss_fn(output, target)
loss.backward()
optim.step()
train_loss += loss.item()

train_loss += loss.item() * data.size(0)
pred = output.argmax(dim=1, keepdim=True)
# view_as() is used to make sure the shape of pred and target are the same

# Convert multi-class target to correct shape
if len(target.size()) > 1:
target = target.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()

correct += pred.eq(target.view_as(pred)).sum().item()
total_samples += data.size(0)

if test_loader is not None:
test_loss, test_acc = self.test(model, test_loader, loss_fn, device)
print(f"Train Loss: {train_loss/(batch_idx+1):.6f} | Train Acc: {correct/((batch_idx+1)*len(data)):.6f} | Test Loss: {test_loss:.6f} | Test Acc: {test_acc:.6f}")

acc = correct / len(dloader.dataset)
return train_loss, acc
print(f"Train Loss: {train_loss/total_samples:.6f} | Train Acc: {correct/total_samples:.6f} | Test Loss:{test_loss:.6f} | Test Acc: {test_acc:.6f}")

avg_loss = train_loss / total_samples
acc = correct / total_samples
return avg_loss, acc


def train_mask(self, model:nn.Module, mask,optim, dloader, loss_fn, device: torch.device, **kwargs) -> Tuple[float, float]:
"""TODO: generate docstring
"""
Expand Down

0 comments on commit 02ab4ff

Please sign in to comment.