import torch


class EMA:
    """Exponential moving average (EMA) of model weights.

    Keeps a "shadow" copy of every trainable parameter that is updated as a
    running average after each optimizer step; recent weights get weight
    (1 - decay), the running average gets weight ``decay``.

    Usage::

        ema = EMA(model, 0.999)
        ema.register()

        # during training, after the parameters were updated:
        optimizer.step()
        ema.update()

        # before evaluation swap the averaged weights in; after evaluation
        # (and after saving the model) put the live weights back:
        ema.apply_shadow()
        # ... evaluate ...
        ema.restore()
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay   # weight of the running average in [0, 1)
        self.shadow = {}     # name -> EMA value of the parameter
        self.backup = {}     # name -> live weights saved by apply_shadow()

    def register(self):
        """Snapshot every trainable parameter as the initial shadow value."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Fold the current weights into the shadow average.

        shadow = (1 - decay) * param + decay * shadow
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        """Swap the averaged weights in, backing up the live weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        """Undo apply_shadow(): put the backed-up live weights back."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}


class FGM:
    """Fast Gradient Method adversarial training on the embedding matrix.

    For each batch x:
      1. forward/backward on x to get the normal gradients
      2. add the perturbation r = epsilon * grad / ||grad|| to the embedding
         (i.e. compute x + r)
      3. forward/backward on x + r, accumulating the adversarial gradients
         on top of step 1's gradients
      4. restore the embedding to its value from step 1
      5. update the parameters with the accumulated gradients

    Usage::

        fgm = FGM(model)
        for batch_input, batch_label in data:
            # normal training step
            loss = model(batch_input, batch_label)
            loss.backward()                 # normal gradients
            # adversarial step
            fgm.attack()                    # perturb the embedding
            loss_adv = model(batch_input, batch_label)
            loss_adv.backward()             # accumulate adversarial gradients
            fgm.restore()                   # restore the embedding
            # parameter update
            optimizer.step()
            model.zero_grad()
    """

    def __init__(self, model):
        self.model = model
        self.backup = {}  # name -> embedding weights before the perturbation

    def attack(self, epsilon=1., emb_name='word_embeddings.weight'):
        """Add an L2-normalized gradient perturbation to the embedding.

        emb_name must match the name of the embedding parameter in your model.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                # skip zero gradients: r would be 0/0 = NaN
                if norm != 0:
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings.weight'):
        """Restore the embedding weights saved by attack().

        BUG FIX: the default used to be 'emb.', which did not match
        attack()'s default ('word_embeddings.weight'); with defaults restore()
        matched nothing, silently dropped the backup, and left the adversarial
        perturbation in the weights.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
import torch


class PGD:
    """Projected Gradient Descent adversarial training on the embedding.

    For each batch x:
      1. forward/backward on x to get (and back up) the normal gradients
      For each of K attack steps t:
      2. add r = alpha * grad / ||grad|| to the embedding (projecting back
         into the epsilon-ball around the original weights if it escapes)
      3. if t is not the last step: zero the gradients and forward/backward
         on x + r
      4. if t is the last step: restore the step-1 gradients so the final
         adversarial gradients accumulate on top of them
      5. restore the embedding to its value from step 1
      6. update the parameters with the accumulated gradients

    Usage::

        pgd = PGD(model)
        K = 3
        for batch_input, batch_label in data:
            # normal training step
            loss = model(batch_input, batch_label)
            loss.backward()                 # normal gradients
            pgd.backup_grad()
            # adversarial steps
            for t in range(K):
                pgd.attack(is_first_attack=(t == 0))  # backs up weights on first attack
                if t != K - 1:
                    model.zero_grad()
                else:
                    pgd.restore_grad()
                loss_adv = model(batch_input, batch_label)
                loss_adv.backward()         # accumulate adversarial gradients
            pgd.restore()                   # restore the embedding
            # parameter update
            optimizer.step()
            model.zero_grad()
    """

    def __init__(self, model):
        self.model = model
        self.emb_backup = {}   # name -> embedding weights before attack
        self.grad_backup = {}  # name -> gradients from the clean forward pass

    def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings.weight', is_first_attack=False):
        """Take one projected gradient step on the embedding weights.

        emb_name must match the name of the embedding parameter in your model.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and param.grad is not None:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                # skip zero gradients: r would be 0/0 = NaN
                if norm != 0:
                    r_at = alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, epsilon)

    def restore(self, emb_name='word_embeddings.weight'):
        """Restore the embedding weights saved by the first attack()."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        """Project param_data back into the epsilon-ball around the backup."""
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        """Back up the gradients of the clean forward/backward pass.

        ROBUSTNESS FIX: skip parameters whose .grad is still None (frozen or
        unused parameters); the original called None.clone() and crashed.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        """Restore the gradients saved by backup_grad()."""
        for name, param in self.model.named_parameters():
            # only restore what was actually backed up (see backup_grad)
            if param.requires_grad and name in self.grad_backup:
                param.grad = self.grad_backup[name]
+ipython-genutils==0.2.0 +ipywidgets==7.5.1 +isort==4.3.21 +itsdangerous==1.1.0 +jdcal==1.4.1 +jedi==0.15.1 +jieba==0.42.1 +Jinja2==2.10.3 +jmespath==0.9.5 +joblib==0.13.2 +JPype1==0.7.0 +json5==0.8.5 +jsonschema==3.0.2 +jupyter==1.0.0 +jupyter-client==5.3.3 +jupyter-console==6.0.0 +jupyter-core==4.5.0 +jupyterlab==1.1.4 +jupyterlab-server==1.0.6 +Keras==2.3.1 +Keras-Applications==1.0.8 +Keras-Preprocessing==1.1.0 +keyring==18.0.0 +kiwisolver==1.1.0 +lazy-object-proxy==1.4.2 +libarchive-c==2.8 +llvmlite==0.29.0 +locket==0.2.0 +lxml==4.4.1 +Markdown==3.2.1 +MarkupSafe==1.1.1 +matplotlib==3.1.1 +mccabe==0.6.1 +menuinst==1.4.16 +missingno==0.4.2 +mistune==0.8.4 +mkl-fft==1.0.14 +mkl-random==1.1.0 +mkl-service==2.3.0 +mock==3.0.5 +more-itertools==7.2.0 +mpmath==1.1.0 +msgpack==0.6.1 +multipledispatch==0.6.0 +navigator-updater==0.2.1 +nbconvert==5.6.0 +nbformat==4.4.0 +networkx==2.3 +nltk==3.4.5 +nose==1.3.7 +notebook==6.0.1 +numba==0.45.1 +numexpr==2.7.0 +numpy==1.16.5 +numpydoc==0.9.1 +olefile==0.46 +openpyxl==3.0.0 +packaging==19.2 +pandas==0.25.1 +pandocfilters==1.4.2 +parso==0.5.1 +partd==1.0.0 +path.py==12.0.1 +pathlib2==2.3.5 +patsy==0.5.1 +pep8==1.7.1 +pickleshare==0.7.5 +Pillow==6.2.0 +pkginfo==1.5.0.1 +pluggy==0.13.0 +ply==3.11 +prometheus-client==0.7.1 +prompt-toolkit==2.0.10 +protobuf==3.11.3 +psutil==5.6.3 +py==1.8.0 +pyahocorasick==1.4.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.5.0 +pycosat==0.6.3 +pycparser==2.19 +pycrypto==2.6.1 +pycurl==7.43.0.3 +pyflakes==2.1.1 +Pygments==2.4.2 +pyhanlp==0.1.63 +pylint==2.4.2 +pyodbc==4.0.27 +pyOpenSSL==19.0.0 +pyparsing==2.4.2 +pyreadline==2.1 +pyrsistent==0.15.4 +PySocks==1.7.1 +pytest==5.2.1 +pytest-arraydiff==0.3 +pytest-astropy==0.5.0 +pytest-doctestplus==0.4.0 +pytest-openfiles==0.4.0 +pytest-remotedata==0.3.2 +python-dateutil==2.8.0 +pytz==2019.3 +PyWavelets==1.0.3 +pywin32==223 +pywinpty==0.5.5 +PyYAML==5.1.2 +pyzmq==18.1.0 +QtAwesome==0.6.0 +qtconsole==4.5.5 +QtPy==1.9.0 +regex==2020.2.20 
+requests==2.22.0 +rope==0.14.0 +rouge==1.0.0 +rsa==4.0 +ruamel-yaml==0.15.46 +s3transfer==0.3.3 +sacremoses==0.0.41 +scikit-image==0.15.0 +scikit-learn==0.21.3 +scipy==1.3.0 +seaborn==0.9.0 +Send2Trash==1.5.0 +sentencepiece==0.1.85 +simplegeneric==0.8.1 +singledispatch==3.4.0.3 +six==1.12.0 +smart-open==1.10.0 +snowballstemmer==2.0.0 +sortedcollections==1.1.2 +sortedcontainers==2.1.0 +soupsieve==1.9.3 +Sphinx==2.2.0 +sphinxcontrib-applehelp==1.0.1 +sphinxcontrib-devhelp==1.0.1 +sphinxcontrib-htmlhelp==1.0.2 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==1.0.2 +sphinxcontrib-serializinghtml==1.1.3 +sphinxcontrib-websupport==1.1.2 +spyder==3.3.6 +spyder-kernels==0.5.2 +SQLAlchemy==1.3.9 +statsmodels==0.10.1 +sympy==1.4 +tables==3.5.2 +tblib==1.4.0 +tensorboard==1.13.1 +tensorflow==1.13.2 +tensorflow-estimator==1.13.0 +termcolor==1.1.0 +terminado==0.8.2 +testpath==0.4.2 +tokenizers==0.5.2 +toolz==0.10.0 +torch==1.4.0+cpu +tornado==6.0.3 +tqdm==4.36.1 +traitlets==4.3.3 +transformers==2.8.0 +unicodecsv==0.14.1 +urllib3==1.24.2 +wcwidth==0.1.7 +webencodings==0.5.1 +Werkzeug==0.16.0 +widgetsnbextension==3.5.1 +win-inet-pton==1.1.0 +win-unicode-console==0.5 +wincertstore==0.2 +wrapt==1.11.2 +xlrd==1.2.0 +XlsxWriter==1.2.1 +xlwings==0.15.10 +xlwt==1.3.0 +zict==1.0.0 +zipp==0.6.0