-
Notifications
You must be signed in to change notification settings - Fork 7
/
attack.py
36 lines (31 loc) · 1.3 KB
/
attack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn.functional as F
class AttackPGD():
    """Untargeted L-infinity PGD (projected gradient ascent) attack.

    Each step moves the input by ``step_size * sign(grad)`` of the summed
    cross-entropy loss w.r.t. the true targets, then projects the result back
    into the L-infinity ball of radius ``eps`` around the clean input and
    clamps it into ``[down, up]``.

    Args:
        net: Module exposing ``parameters()`` whose output is passed to
            ``F.cross_entropy`` as logits.
        eps: Radius of the L-infinity perturbation ball.
        step_size: Magnitude of each signed-gradient step.
        num_steps: Number of gradient steps.
        up: Upper bound of the valid input range; a tensor broadcastable to
            the inputs, or a Python number.
        down: Lower bound of the valid input range; same forms as ``up``.
        random_start: If True, start from the clean input plus Gaussian noise
            (std ``eps / 4``) clipped to ``[-eps / 2, eps / 2]``.

    Note:
        The network's train/eval mode is not changed here; call ``net.eval()``
        beforehand if layers such as BatchNorm/Dropout should not run in
        training mode during the attack.
    """

    def __init__(self, net, eps, step_size, num_steps, up, down, random_start=True):
        super(AttackPGD, self).__init__()
        self.net = net
        self.rand = random_start
        self.step_size = step_size
        self.num_steps = num_steps
        self.eps = eps
        self.up = up
        self.down = down

    @staticmethod
    def _as_bound(value, like):
        """Return ``value`` as a tensor usable with ``torch.min``/``torch.max`` against ``like``."""
        if torch.is_tensor(value):
            # Keep the bound's own dtype (as before); only align the device.
            return value.to(device=like.device)
        # Python scalars/sequences: torch.max(tensor, float) is not a valid overload.
        return torch.as_tensor(value, dtype=like.dtype, device=like.device)

    def find(self, inputs, targets):
        """Return adversarial versions of ``inputs`` for labels ``targets``.

        Args:
            inputs: Clean input batch. It is never modified, and no gradient
                is accumulated into it.
            targets: Labels in the form accepted by ``F.cross_entropy``.

        Returns:
            A tensor shaped like ``inputs`` with no autograd history, within
            ``eps`` (L-infinity) of ``inputs`` and clamped into ``[down, up]``.

        The ``requires_grad`` flag of every network parameter is restored on
        exit, even if the attack raises.
        """
        # Remember each parameter's flag so it can be restored exactly.
        requires_grads = [p.requires_grad for p in self.net.parameters()]
        # Only the input needs gradients; freezing the weights avoids
        # accumulating .grad on them during the attack.
        self.net.requires_grad_(False)
        try:
            # Detached reference point: projecting against a grad-requiring
            # `inputs` would attach history to x and break later steps.
            x_nat = inputs.detach()
            down = self._as_bound(self.down, x_nat)
            up = self._as_bound(self.up, x_nat)

            x = x_nat
            if self.rand:
                init_noise = torch.zeros_like(x).normal_(0, self.eps / 4)
                x = x + torch.clamp(init_noise, -self.eps / 2, self.eps / 2)
                x = torch.min(torch.max(x, down), up)

            # Gradients are required even if the caller is inside torch.no_grad().
            with torch.enable_grad():
                for _ in range(self.num_steps):
                    # Fresh leaf each step; never aliases x_nat.
                    x = x.detach().requires_grad_()
                    logits = self.net(x)
                    loss = F.cross_entropy(logits, targets, reduction='sum')
                    # autograd.grad does not write into any tensor's .grad.
                    grad, = torch.autograd.grad(loss, x)
                    x = torch.add(x.detach(), torch.sign(grad), alpha=self.step_size)
                    # Project onto the eps-ball, then into the valid range.
                    x = torch.min(torch.max(x, x_nat - self.eps), x_nat + self.eps)
                    x = torch.min(torch.max(x, down), up)
        finally:
            for p, r in zip(self.net.parameters(), requires_grads):
                p.requires_grad_(r)
        return x