Commit b5d0fb2 (1 parent: 0fac1c7)
Showing 6 changed files with 441 additions and 0 deletions.
class EMA():
    '''
    Exponential moving average of the model weights; recent updates receive a higher weight.
    usage:
        # initialize
        ema = EMA(model, 0.999)
        ema.register()
        # during training, after the optimizer has updated the parameters, sync the shadow weights
        def train():
            optimizer.step()
            ema.update()
        # before eval, apply the shadow weights;
        # after eval (and after saving the model), restore the original model parameters
        def evaluate():
            ema.apply_shadow()
            # evaluate
            ema.restore()
    '''
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        # snapshot the current trainable parameters as the initial shadow weights
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        # shadow = decay * shadow + (1 - decay) * current weights
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        # back up the current weights and load the shadow weights for evaluation
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        # restore the backed-up training weights
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
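
To make the docstring above concrete, here is a minimal, self-contained sketch of EMA in a training/eval cycle. The small linear model, optimizer, loss, and random batches are placeholders invented for illustration; only the EMA calls come from the class above.

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                      # hypothetical stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

ema = EMA(model, 0.999)
ema.register()                                # snapshot the initial weights as the shadow copy

for step in range(100):                       # dummy training loop with random data
    x = torch.randn(32, 10)
    y = torch.randint(0, 2, (32,))
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()                              # sync the shadow weights after each optimizer step

ema.apply_shadow()                            # evaluate (and save) with the averaged weights
# ... run evaluation / torch.save(model.state_dict(), ...) here ...
ema.restore()                                 # switch back to the raw training weights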
import torch


class FGM():
    '''
    Fast Gradient Method (FGM) adversarial training.
    For each batch x:
        1. Run the forward pass on x and backpropagate to get the gradients.
        2. Compute the perturbation r from the gradient of the embedding matrix and add it to the
           current embeddings, which is equivalent to using x + r.
        3. Run the forward pass on x + r and backpropagate to get the adversarial gradients,
           accumulating them onto the gradients from step 1.
        4. Restore the embeddings to their values from step 1.
        5. Update the parameters using the gradients from step 3.
    usage:
        # initialize
        fgm = FGM(model)
        for batch_input, batch_label in data:
            # normal training step
            loss = model(batch_input, batch_label)
            loss.backward()  # backprop to get the normal gradients
            # adversarial step
            fgm.attack()  # add the adversarial perturbation to the embeddings
            loss_adv = model(batch_input, batch_label)
            loss_adv.backward()  # backprop, accumulating the adversarial gradients onto the normal ones
            fgm.restore()  # restore the embedding parameters
            # gradient descent, update the parameters
            optimizer.step()
            model.zero_grad()
    '''
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1., emb_name='word_embeddings.weight'):
        # replace emb_name with the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings.weight'):
        # replace emb_name with the name of the embedding parameter in your model
        # (it must match the name used in attack(), otherwise the backup is never restored)
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
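
Below is a runnable sketch of a single FGM training step. The ToyClassifier, its optimizer, loss, and random token batch are hypothetical placeholders, chosen so that the embedding parameter is literally named 'word_embeddings.weight' and matches the default emb_name; only the FGM calls follow the usage in the docstring above.

import torch
import torch.nn as nn

class ToyClassifier(nn.Module):
    def __init__(self, vocab_size=100, dim=16, num_labels=2):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, dim)  # parameter name: 'word_embeddings.weight'
        self.fc = nn.Linear(dim, num_labels)

    def forward(self, input_ids):
        return self.fc(self.word_embeddings(input_ids).mean(dim=1))

model = ToyClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
fgm = FGM(model)

input_ids = torch.randint(0, 100, (8, 12))    # dummy batch of token ids
labels = torch.randint(0, 2, (8,))

loss = criterion(model(input_ids), labels)
loss.backward()                               # clean gradients

fgm.attack()                                  # perturb the embedding matrix
loss_adv = criterion(model(input_ids), labels)
loss_adv.backward()                           # accumulate adversarial gradients onto the clean ones
fgm.restore()                                 # undo the perturbation

optimizer.step()
model.zero_grad()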
import torch


class PGD():
    '''
    Projected Gradient Descent (PGD) adversarial training.
    For each batch x:
        1. Run the forward pass on x, backpropagate to get the gradients, and back them up.
        For each step t:
            2. Compute the perturbation r from the gradient of the embedding matrix and add it to the
               current embeddings, i.e. x + r (project back inside the epsilon ball if it falls outside).
            3. If t is not the last step: zero the gradients, then run forward/backward on the current
               x + r to get new gradients.
            4. If t is the last step: restore the gradients from step 1, run forward/backward on the
               final x + r, and accumulate its gradients onto those from step 1.
        5. Restore the embeddings to their values from step 1.
        6. Update the parameters using the gradients from step 4.
    usage:
        pgd = PGD(model)
        K = 3
        for batch_input, batch_label in data:
            # normal training step
            loss = model(batch_input, batch_label)
            loss.backward()  # backprop to get the normal gradients
            pgd.backup_grad()
            # adversarial steps
            for t in range(K):
                pgd.attack(is_first_attack=(t == 0))  # add the adversarial perturbation to the embeddings; back up param.data on the first attack
                if t != K - 1:
                    model.zero_grad()
                else:
                    pgd.restore_grad()
                loss_adv = model(batch_input, batch_label)
                loss_adv.backward()  # backprop, accumulating the adversarial gradients onto the normal ones
            pgd.restore()  # restore the embedding parameters
            # gradient descent, update the parameters
            optimizer.step()
            model.zero_grad()
    '''

    def __init__(self, model):
        self.model = model
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings.weight', is_first_attack=False):
        # replace emb_name with the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, epsilon)

    def restore(self, emb_name='word_embeddings.weight'):
        # replace emb_name with the name of the embedding parameter in your model
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        # project the perturbed embeddings back onto the epsilon ball around the original values
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        # keep a copy of the clean gradients so the last attack step can start from them
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = self.grad_backup[name]
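
For comparison with FGM, here is a sketch of one PGD training step with K = 3 inner attack steps. It reuses the hypothetical ToyClassifier, optimizer, criterion, and dummy batch from the FGM sketch above; none of those names come from this commit. The inner loop zeroes the gradients on intermediate steps and restores the backed-up clean gradients on the last step, as the docstring describes.

pgd = PGD(model)
K = 3

loss = criterion(model(input_ids), labels)
loss.backward()                               # clean gradients
pgd.backup_grad()                             # keep a copy of them

for t in range(K):
    pgd.attack(is_first_attack=(t == 0))      # perturb embeddings; back them up on the first step
    if t != K - 1:
        model.zero_grad()                     # intermediate steps: only need gradients w.r.t. x + r
    else:
        pgd.restore_grad()                    # last step: start again from the clean gradients
    loss_adv = criterion(model(input_ids), labels)
    loss_adv.backward()                       # accumulate adversarial gradients
pgd.restore()                                 # undo the embedding perturbation

optimizer.step()
model.zero_grad()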