Skip to content

Commit

Permalink
fix test files
Browse files Browse the repository at this point in the history
Signed-off-by: weijingchen <talkingwallace@sohu.com>

Signed-off-by: cwj <talkingwallace@sohu.com>
  • Loading branch information
talkingwallace committed Nov 29, 2023
1 parent 402ceec commit 28e1af5
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def _get_activation(activation):
raise ValueError(f"Unsupported activation: {activation}")



class PassportBlock(nn.Module):

def __init__(self, passport_distribute: Literal['gaussian', 'uniform'], passport_mode: Literal['single', 'multi']):
Expand Down
5 changes: 0 additions & 5 deletions python/fate/ml/nn/test/test_fedpass_alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,29 +172,24 @@ def set_seed(seed):
set_seed(42)


# 1. 加载数据
train_data = torchvision.datasets.CIFAR10(root='./cifar10',
train=True,
download=True,
transform=torchvision.transforms.ToTensor())

# 2. 为每个数字保存其索引
digit_indices = [[] for _ in range(10)]
for idx, (_, label) in enumerate(train_data):
digit_indices[label].append(idx)

# 3. 从每个数字的索引中随机选择300个样本作为训练集
selected_train_indices = []
for indices in digit_indices:
selected_train_indices.extend(torch.randperm(len(indices))[:500].tolist())

# 4. 从剩下的索引中随机选择100个样本作为验证集
selected_val_indices = []
for indices in digit_indices:
remaining_indices = [idx for idx in indices if idx not in selected_train_indices]
selected_val_indices.extend(torch.randperm(len(remaining_indices))[:100].tolist())

# 5. 使用Subset获取训练集和验证集
subset_train_data = torch.utils.data.Subset(train_data, selected_train_indices)
subset_val_data = torch.utils.data.Subset(train_data, selected_val_indices)

Expand Down
1 change: 0 additions & 1 deletion python/fate/ml/nn/test/test_hetero_nn_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def set_seed(seed):
training_args=args
)
trainer.train()
print('cwj done')
# pred = trainer.predict(dataset)
# # compute auc
# from sklearn.metrics import roc_auc_score
Expand Down

0 comments on commit 28e1af5

Please sign in to comment.