色々なタスクを実行するスクリプトを置くリポジトリです。PyTorch を使用しています。
スクリプトを実行する python script_ner_with_bert.py
テストを実行する pytest --ignore=./tests/TCN --ignore=./tests/trellisnet --ignore=./TCN/
テストを実行する(時間がかかるテストをスキップする) export SKIP_BERT=TRUE ; pytest --ignore=./tests/TCN --ignore=./tests/trellisnet --ignore=./TCN/
環境をインストールする pipenv install --dev
スクリプトを実行する pipenv run python script_ner_with_bert.py
テストを実行する pipenv run pytest --ignore=./tests/TCN --ignore=./tests/trellisnet --ignore=./TCN/
テストを実行する(時間がかかるテストをスキップする) export SKIP_BERT=TRUE ; pipenv run pytest --ignore=./tests/TCN --ignore=./tests/trellisnet --ignore=./TCN/
- ./script_xxx.py
- 各タスクを実行するスクリプトです。
- TrellisNet を利用するスクリプトの場合は、./trellisnet/ を取得してから実行する必要があります。
- ./tests/
- テストですがテストとみせかけて各種モデルの仕様のメモです。
pytest
でテストを実行します。ただし、- ./trellisnet/ を取得していない場合は TrellisNet のテストはできないので
pytest --ignore=./tests/trellisnet
とする必要があります。 - ./TCN/ を取得している場合、このリポジトリ内にファイル名に test が付くファイルがあるので
--ignore=./TCN/
も付ける必要があります。
- ./trellisnet/ を取得していない場合は TrellisNet のテストはできないので
- 環境変数
export SKIP_BERT=TRUE
を設定すると BERT を読み込むテスト(時間がかかる)をスキップできます。
- テストですがテストとみせかけて各種モデルの仕様のメモです。
- ./data/
- 取得したデータを置くところです。
- ./models/
- 自分で定義したモデルを置くところです。
- ./weights/
- 学習したモデルの重みを置くところです。
- ./utils/
- 自分で定義した関数を置くところです。
- ./TCN/
- TCN のソースコードを置くところです。デフォルトで同梱していません。このディレクトリ内で以下のように取得してください。なお、TCN の本家のリポジトリのフォルダ構成ではモデルを読み込めないので、必ず CookieBox26 の Fork を取得してください。
git clone https://github.com/CookieBox26/TCN.git
- TCN のソースコードを置くところです。デフォルトで同梱していません。このディレクトリ内で以下のように取得してください。なお、TCN の本家のリポジトリのフォルダ構成ではモデルを読み込めないので、必ず CookieBox26 の Fork を取得してください。
- ./trellisnet/
- TrellisNet のソースコードを置くところです。デフォルトで同梱していません。このディレクトリ内で以下のように取得してください。なお、TrellisNet の本家のリポジトリのフォルダ構成ではモデルを読み込めないので、必ず CookieBox26 の Fork を取得してください。
git clone https://github.com/CookieBox26/trellisnet.git
- TrellisNet のソースコードを置くところです。デフォルトで同梱していません。このディレクトリ内で以下のように取得してください。なお、TrellisNet の本家のリポジトリのフォルダ構成ではモデルを読み込めないので、必ず CookieBox26 の Fork を取得してください。
MNISTを1次元系列として扱ってどの数字が分類するタスクをGRUまたはTCNで解きます。10エポック学習したパラメータを同梱しています。これを指定して訓練をスキップすると以下のように表示されます。
# GRU
テストデータでの平均損失 0.09641245974451304
テストデータでの正解率 9694/10000 (96.94%)
# TCN
テストデータでの平均損失 0.07634854557925573
テストデータでの正解率 9805/10000 (98.05%)
正のエポック数を指定すると以下のように学習が始まります。
エポック 0
10/938 バッチ (640/60000 サンプル) 流れました 最近 10 バッチの平均損失 2.3024290084838865
20/938 バッチ (1280/60000 サンプル) 流れました 最近 10 バッチの平均損失 2.2999905586242675
30/938 バッチ (1920/60000 サンプル) 流れました 最近 10 バッチの平均損失 2.277213621139526
40/938 バッチ (2560/60000 サンプル) 流れました 最近 10 バッチの平均損失 1.9509142518043519
50/938 バッチ (3200/60000 サンプル) 流れました 最近 10 バッチの平均損失 1.236298155784607
Permuted MNIST も学習できます。
以下補足です。
- 手抜きのためにTCNの著者のコードではなく自分でクラスをかきかえたコード ./models/tcn.py を参照していますが動作は同じです(そのため ./TCN/ を取得しなくても動きます)。
- TCNのネットワーク構造とオプティマイザはTCNの原論文に倣っています。
- GRUのネットワーク構造とオプティマイザは同論文のLSTMのセッティングに似せたものです(Grad Clip はしていません)。
- バッチサイズ64は著者の TCN による Seq. MNIST のコードにあったデフォルト値です。
WNUT’17 の固有表現抽出タスクをしようとしていますが、まだ適当な文章をモデルに流すところまでしか実装されていません。以下が標準出力に出力されるだけです。
# ここに Some weights of the model checkpoint at bert-large-cased were not used... に始まる警告文が出る.
◆ 適当な文章をモデルに流してみる.→ 14トークン×13クラスの予測結果になっている(サイズが).
torch.Size([1, 14, 13])
※ このスクリプトを実行するには ./trellisnet/ の取得が必要です。
文字レベルの Penn Treebank 予測タスクをしようとしていますが、まだ以下が標準出力に出力されるだけです。
訓練データ: 5101618 字
検証データ: 399782 字
テストデータ: 449945 字
ユニーク文字数: 50 字
{' ': 0, '#': 1, '$': 2, '&': 3, "'": 4, '*': 5, '-': 6, '.': 7, '/': 8, '0': 9, '1': 10, '2': 11, '3': 12, '4': 13, '5': 14, '6': 15, '7': 16, '8': 17, '9': 18, '<': 19, '>': 20, 'N': 21, '\\': 22, 'a': 23, 'b': 24, 'c': 25, 'd': 26, 'e': 27, 'f': 28, 'g': 29, 'h': 30, 'i': 31, 'j': 32, 'k': 33, 'l': 34, 'm': 35, 'n': 36, 'o': 37, 'p': 38, 'q': 39, 'r': 40, 's': 41, 't': 42, 'u': 43, 'v': 44, 'w': 45, 'x': 46, 'y': 47, 'z': 48, 'ÿ': 49}
Weight normalization applied
0 バッチ目 tensor(3.9125, grad_fn=<NllLossBackward>)
1 バッチ目 tensor(3.8557, grad_fn=<NllLossBackward>)
2 バッチ目 tensor(3.6744, grad_fn=<NllLossBackward>)
--use_cuda
を付けて実行すると GPU で学習します。
0 バッチ目 tensor(3.9117, device='cuda:0', grad_fn=<NllLossBackward>)
1 バッチ目 tensor(3.8551, device='cuda:0', grad_fn=<NllLossBackward>)
2 バッチ目 tensor(3.4490, device='cuda:0', grad_fn=<NllLossBackward>)
- https://github.com/pytorch/pytorch/tree/v1.6.0
- PyTorch のリポジトリです(v1.6.0)。
- 特に nn.Module のソースは以下です。
- https://github.com/huggingface/transformers/tree/v3.1.0
- transformers のリポジトリです(v3.1.0)。
- 特に BERT モデルのソースは以下です。
- https://github.com/locuslab/TCN
- TCN のリポジトリです。
- 使い勝手のために以下にフォークしています。
- https://github.com/locuslab/trellisnet
- TrellisNet のリポジトリです。
- 使い勝手のために以下にフォークしています。