Skip to content

Commit

Permalink
Merge pull request #150 from photonshi/malicious_attack_inversion
Browse files Browse the repository at this point in the history
Malicious attack inversion
  • Loading branch information
photonshi authored Dec 2, 2024
2 parents cd614bb + 42019af commit 5a78e1c
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 3 deletions.
1 change: 0 additions & 1 deletion src/configs/algo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,5 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st
malicious_traditional_model_update_attack,
]


default_config_list: List[ConfigType] = [traditional_fl]
# default_config_list: List[ConfigType] = [fedstatic, fedstatic, fedstatic, fedstatic]
Empty file added src/inversefed/data/__init__.py
Empty file.
Empty file added src/inversefed/data/data.py
Empty file.
209 changes: 209 additions & 0 deletions src/inversefed/data/data_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""Repeatable code parts concerning data loading."""


import torch
import torchvision
import torchvision.transforms as transforms

import os

from ..consts import *

from .data import _build_bsds_sr, _build_bsds_dn
from .loss import Classification, PSNR


def construct_dataloaders(dataset, defs, data_path='~/data', shuffle=True, normalize=True):
    """Build the loss object and the train/validation dataloaders for a named dataset.

    Args:
        dataset: Dataset name. One of 'CIFAR10', 'CIFAR100', 'MNIST', 'MNIST_GRAY',
            'ImageNet', 'BSDS-SR', 'BSDS-DN' or 'BSDS-RGB'.
        defs: Settings object; ``defs.augmentations`` and ``defs.batch_size`` are read.
        data_path: Root directory of the data; a leading ``~`` is expanded.
        shuffle: Whether the training loader shuffles its samples.
        normalize: Forwarded to the dataset builder to toggle normalization.

    Returns:
        Tuple ``(loss_fn, trainloader, validloader)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    path = os.path.expanduser(data_path)

    if dataset == 'CIFAR10':
        trainset, validset = _build_cifar10(path, defs.augmentations, normalize)
        loss_fn = Classification()
    elif dataset == 'CIFAR100':
        trainset, validset = _build_cifar100(path, defs.augmentations, normalize)
        loss_fn = Classification()
    elif dataset == 'MNIST':
        trainset, validset = _build_mnist(path, defs.augmentations, normalize)
        loss_fn = Classification()
    elif dataset == 'MNIST_GRAY':
        trainset, validset = _build_mnist_gray(path, defs.augmentations, normalize)
        loss_fn = Classification()
    elif dataset == 'ImageNet':
        trainset, validset = _build_imagenet(path, defs.augmentations, normalize)
        loss_fn = Classification()
    elif dataset == 'BSDS-SR':
        trainset, validset = _build_bsds_sr(path, defs.augmentations, normalize, upscale_factor=3, RGB=True)
        loss_fn = PSNR()
    elif dataset == 'BSDS-DN':
        trainset, validset = _build_bsds_dn(path, defs.augmentations, normalize, noise_level=25 / 255, RGB=False)
        loss_fn = PSNR()
    elif dataset == 'BSDS-RGB':
        trainset, validset = _build_bsds_dn(path, defs.augmentations, normalize, noise_level=25 / 255, RGB=True)
        loss_fn = PSNR()
    else:
        # Fail loudly here instead of hitting an unbound `trainset` further down.
        raise ValueError(f'Unknown dataset: {dataset!r}.')

    # Worker count is capped by both torch's thread count and the configured limit;
    # it is 0 when multithreading is disabled or torch reports a single thread.
    num_workers = 0
    if MULTITHREAD_DATAPROCESSING:
        available_threads = torch.get_num_threads()
        if available_threads > 1:
            num_workers = min(available_threads, MULTITHREAD_DATAPROCESSING)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=min(defs.batch_size, len(trainset)),
        shuffle=shuffle, drop_last=True, num_workers=num_workers, pin_memory=PIN_MEMORY)
    # The validation batch size is capped by the validation set's own length.
    validloader = torch.utils.data.DataLoader(
        validset, batch_size=min(defs.batch_size, len(validset)),
        shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=PIN_MEMORY)

    return loss_fn, trainloader, validloader


def _build_cifar10(data_path, augmentations=True, normalize=True):
    """Create the CIFAR-10 train/validation datasets with their final transforms.

    Args:
        data_path: Root directory passed to torchvision (data is downloaded if needed).
        augmentations: If true, the training set additionally gets a padded random
            32x32 crop and a random horizontal flip.
        normalize: If true, images are normalized with the dataset mean/std;
            otherwise an identity transform takes that slot.

    Returns:
        Tuple ``(trainset, validset)``.
    """
    trainset = torchvision.datasets.CIFAR10(
        root=data_path, train=True, download=True, transform=transforms.ToTensor())
    validset = torchvision.datasets.CIFAR10(
        root=data_path, train=False, download=True, transform=transforms.ToTensor())

    # Use the configured statistics when present, otherwise derive them from the training data.
    if cifar10_mean is not None:
        data_mean, data_std = cifar10_mean, cifar10_std
    else:
        data_mean, data_std = _get_meanstd(trainset)

    if normalize:
        last_step = transforms.Normalize(data_mean, data_std)
    else:
        last_step = transforms.Lambda(lambda x: x)
    base_pipeline = transforms.Compose([transforms.ToTensor(), last_step])

    validset.transform = base_pipeline
    if not augmentations:
        trainset.transform = base_pipeline
        return trainset, validset

    trainset.transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        base_pipeline,
    ])
    return trainset, validset

def _build_cifar100(data_path, augmentations=True, normalize=True):
    """Create the CIFAR-100 train/validation datasets with their final transforms.

    Args:
        data_path: Root directory passed to torchvision (data is downloaded if needed).
        augmentations: If true, the training set additionally gets a padded random
            32x32 crop and a random horizontal flip.
        normalize: If true, images are normalized with the dataset mean/std;
            otherwise an identity transform takes that slot.

    Returns:
        Tuple ``(trainset, validset)``.
    """
    trainset = torchvision.datasets.CIFAR100(
        root=data_path, train=True, download=True, transform=transforms.ToTensor())
    validset = torchvision.datasets.CIFAR100(
        root=data_path, train=False, download=True, transform=transforms.ToTensor())

    # Use the configured statistics when present, otherwise derive them from the training data.
    if cifar100_mean is not None:
        data_mean, data_std = cifar100_mean, cifar100_std
    else:
        data_mean, data_std = _get_meanstd(trainset)

    if normalize:
        last_step = transforms.Normalize(data_mean, data_std)
    else:
        last_step = transforms.Lambda(lambda x: x)
    base_pipeline = transforms.Compose([transforms.ToTensor(), last_step])

    validset.transform = base_pipeline
    if not augmentations:
        trainset.transform = base_pipeline
        return trainset, validset

    trainset.transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        base_pipeline,
    ])
    return trainset, validset


def _build_mnist(data_path, augmentations=True, normalize=True):
    """Create the MNIST train/validation datasets with their final transforms.

    Args:
        data_path: Root directory passed to torchvision (data is downloaded if needed).
        augmentations: If true, the training set additionally gets a padded random
            28x28 crop and a random horizontal flip.
        normalize: If true, images are normalized with the dataset mean/std;
            otherwise an identity transform takes that slot.

    Returns:
        Tuple ``(trainset, validset)``.
    """
    trainset = torchvision.datasets.MNIST(
        root=data_path, train=True, download=True, transform=transforms.ToTensor())
    validset = torchvision.datasets.MNIST(
        root=data_path, train=False, download=True, transform=transforms.ToTensor())

    if mnist_mean is not None:
        data_mean, data_std = mnist_mean, mnist_std
    else:
        # No configured statistics: flatten every training image into one long
        # vector and take a single scalar mean/std over all pixels.
        flat_pixels = torch.cat(
            [trainset[idx][0].reshape(-1) for idx in range(len(trainset))], dim=0)
        data_mean = (torch.mean(flat_pixels, dim=0).item(),)
        data_std = (torch.std(flat_pixels, dim=0).item(),)

    if normalize:
        last_step = transforms.Normalize(data_mean, data_std)
    else:
        last_step = transforms.Lambda(lambda x: x)
    base_pipeline = transforms.Compose([transforms.ToTensor(), last_step])

    validset.transform = base_pipeline
    if not augmentations:
        trainset.transform = base_pipeline
        return trainset, validset

    trainset.transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        base_pipeline,
    ])
    return trainset, validset

def _build_mnist_gray(data_path, augmentations=True, normalize=True):
    """Create the MNIST datasets with an explicit single-channel grayscale conversion.

    Args:
        data_path: Root directory passed to torchvision (data is downloaded if needed).
        augmentations: If true, the training set additionally gets a grayscale
            conversion, a padded random 28x28 crop and a random horizontal flip
            in front of the base pipeline.
        normalize: If true, images are normalized with the dataset mean/std;
            otherwise an identity transform takes that slot.

    Returns:
        Tuple ``(trainset, validset)``.
    """
    trainset = torchvision.datasets.MNIST(
        root=data_path, train=True, download=True, transform=transforms.ToTensor())
    validset = torchvision.datasets.MNIST(
        root=data_path, train=False, download=True, transform=transforms.ToTensor())

    if mnist_mean is not None:
        data_mean, data_std = mnist_mean, mnist_std
    else:
        # No configured statistics: flatten every training image into one long
        # vector and take a single scalar mean/std over all pixels.
        flat_pixels = torch.cat(
            [trainset[idx][0].reshape(-1) for idx in range(len(trainset))], dim=0)
        data_mean = (torch.mean(flat_pixels, dim=0).item(),)
        data_std = (torch.std(flat_pixels, dim=0).item(),)

    if normalize:
        last_step = transforms.Normalize(data_mean, data_std)
    else:
        last_step = transforms.Lambda(lambda x: x)
    base_pipeline = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        last_step,
    ])

    validset.transform = base_pipeline
    if not augmentations:
        trainset.transform = base_pipeline
        return trainset, validset

    # The base pipeline converts to grayscale again after these steps.
    trainset.transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        base_pipeline,
    ])
    return trainset, validset


def _build_imagenet(data_path, augmentations=True, normalize=True):
    """Create the ImageNet train/validation datasets with their final transforms.

    Args:
        data_path: Root directory passed to torchvision's ImageNet dataset.
        augmentations: If true, the training set uses a random resized 224 crop and
            a random horizontal flip instead of the resize + center-crop pipeline.
        normalize: If true, images are normalized with the dataset mean/std;
            otherwise an identity transform takes that slot.

    Returns:
        Tuple ``(trainset, validset)``.
    """
    trainset = torchvision.datasets.ImageNet(
        root=data_path, split='train', transform=transforms.ToTensor())
    validset = torchvision.datasets.ImageNet(
        root=data_path, split='val', transform=transforms.ToTensor())

    # Use the configured statistics when present, otherwise derive them from the training data.
    if imagenet_mean is not None:
        data_mean, data_std = imagenet_mean, imagenet_std
    else:
        data_mean, data_std = _get_meanstd(trainset)

    def tensor_tail():
        """Return the shared closing steps: tensor conversion, then normalize or identity."""
        if normalize:
            closing = transforms.Normalize(data_mean, data_std)
        else:
            closing = transforms.Lambda(lambda x : x)
        return [transforms.ToTensor(), closing]

    eval_pipeline = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224)] + tensor_tail())

    validset.transform = eval_pipeline
    if not augmentations:
        trainset.transform = eval_pipeline
        return trainset, validset

    trainset.transform = transforms.Compose(
        [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()] + tensor_tail())
    return trainset, validset


def _get_meanstd(dataset):
cc = torch.cat([trainset[i][0].reshape(3, -1) for i in range(len(trainset))], dim=1)
data_mean = torch.mean(cc, dim=1).tolist()
data_std = torch.std(cc, dim=1).tolist()
return data_mean, data_std
Empty file added src/inversefed/data/datasets.py
Empty file.
114 changes: 114 additions & 0 deletions src/inversefed/data/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Define various loss functions and bundle them with appropriate metrics."""

import torch
import numpy as np


class Loss:
    """Abstract base class for loss functions bundled with a target metric.

    Collects information about the 'higher-level' loss function: the evaluation
    of the loss itself and the evaluation of the metric that is actually targeted.
    Subclasses override ``__call__`` and ``metric``; the subclasses in this module
    return ``(value, name, format)`` tuples from both.
    """

    def __init__(self):
        """Init."""
        pass

    def __call__(self, reference, argmin):
        """Return l(x, y).

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError()

    def metric(self, reference, argmin):
        """Return the actually sought metric.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError()


class PSNR(Loss):
    """Mean-squared-error objective reported together with average PSNR.

    The minimized criterion is (half of) the MSE loss; the metric of interest
    is the average peak signal-to-noise ratio.
    """

    def __init__(self):
        """Set up the underlying torch MSE criterion with mean reduction."""
        self.loss_fn = torch.nn.MSELoss(reduction='mean')

    def __call__(self, x=None, y=None):
        """Return ``(value, name, format)`` for the loss, or ``(name, format)`` when ``x`` is None."""
        label, fmt = 'MSE', '.6f'
        if x is not None:
            return 0.5 * self.loss_fn(x, y), label, fmt
        return label, fmt

    def metric(self, x=None, y=None):
        """Return ``(value, name, format)`` for the PSNR metric, or ``(name, format)`` when ``x`` is None."""
        label, fmt = 'avg PSNR', '.3f'
        if x is not None:
            return self.psnr_compute(x, y), label, fmt
        return label, fmt

    @staticmethod
    def psnr_compute(img_batch, ref_batch, batched=False, factor=1.0):
        """Compute the PSNR between an image batch and a reference batch.

        Args:
            img_batch: Tensor of images; it is detached before use.
            ref_batch: Reference tensor compared against ``img_batch``.
            batched: If true, a single PSNR is computed over the whole tensors.
                Otherwise ``img_batch`` must be 4-dimensional and the PSNR is
                computed per sample along the first dimension and averaged.
            factor: Peak value; its square is the numerator of the ratio.

        Returns:
            The PSNR value: ``nan`` for a non-finite MSE, ``inf`` for a zero MSE.
        """
        def single_psnr(candidate, target):
            squared_error = ((candidate - target)**2).mean()
            if not torch.isfinite(squared_error):
                return float('nan')
            if squared_error > 0:
                return (10 * torch.log10(factor**2 / squared_error)).item()
            return float('inf')

        detached = img_batch.detach()
        if batched:
            return single_psnr(detached, ref_batch)

        num_samples, _, _, _ = img_batch.shape
        per_sample = [
            single_psnr(detached[idx, :, :, :], ref_batch[idx, :, :, :])
            for idx in range(num_samples)
        ]
        return np.mean(per_sample)


class Classification(Loss):
    """Cross-entropy objective for classification, reported together with accuracy.

    The minimized criterion is cross entropy (the softmax is baked into the
    evaluation); the metric of interest is total accuracy.
    """

    def __init__(self):
        """Set up the underlying torch cross-entropy criterion with mean reduction."""
        self.loss_fn = torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')

    def __call__(self, x=None, y=None):
        """Return ``(value, name, format)`` for the loss, or ``(name, format)`` when ``x`` is None."""
        label, fmt = 'CrossEntropy', '1.5f'
        if x is not None:
            return self.loss_fn(x, y), label, fmt
        return label, fmt

    def metric(self, x=None, y=None):
        """Return ``(value, name, format)`` for the accuracy, or ``(name, format)`` when ``x`` is None."""
        label, fmt = 'Accuracy', '6.2%'
        if x is None:
            return label, fmt
        # Fraction of samples whose arg-max along dim 1 equals the target.
        predictions = x.data.argmax(dim=1)
        hits = (predictions == y).sum().float()
        accuracy = hits / y.shape[0]
        return accuracy.detach(), label, fmt
3 changes: 1 addition & 2 deletions src/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import yolo
from utils.types import ConfigType

from inversefed.reconstruction_algorithms import loss_steps

class ModelUtils:
def __init__(self, device: torch.device, config: ConfigType) -> None:
self.device = device
Expand Down Expand Up @@ -197,6 +195,7 @@ def train_classification(
print("here, applying softmax")
output = nn.functional.log_softmax(output, dim=1) # type: ignore
if kwargs.get("gia", False):
from inversefed.reconstruction_algorithms import loss_steps
# Sum the loss and create gradient graph like in loss_steps
# Use modified loss_steps function that returns loss
model.eval()
Expand Down

0 comments on commit 5a78e1c

Please sign in to comment.