wandb와 pytorch-lightning으로 깔끔하게 구현해보는 RNN 패밀리 👨👩👧👦
RNNCell
/ RNN
/ LSTMCell
/ LSTM
/ BiLSTMCell
/ BiLSTM
cleanrnns/rnns.py
에 깔끔하게 구현해뒀어요 😊
예를 들면?
-
RNNCell
,RNN
-
LSTMCell
,LSTM
-
BiLSTMCell
,BiLSTM
- 🚧
GRUCell
,GRU
🚧
모델만을 종속변인으로 두고 실험하는 것이 가능하도록 설계해뒀어요 📝
예를 들면?
모델 | f1 score (test) | 가중치 | 소요시간 | hidden_size |
하이퍼파라미터 | wandb 로그 |
---|---|---|---|---|---|---|
RNN | 0.8411 | 16.4M | 18m 19s | 512 | 통일 | 학습 / 테스트 |
LSTM | 0.8522 | 16.4M | 20m 18s | 443 | 통일 | 학습 / 테스트 |
BiLSTM | 0.8539 | 16.4M | 36m 12s | 387 | 통일 | 학습 / 테스트 |
동일한 입력에 대한 예측값도 웹 데모 에서 한눈에 비교가 가능해요 |
---|
단계별로 스크립트를 작성해뒀어요 🪜
예를 들면?
-
run_build_nsmc.py
(데이터 구축) -
run_build_tokenizer.py
(토크나이저 구축) -
run_train.py
(모델 훈련) -
run_test.py
(모델 평가) -
run_deploy.py
(모델 배포) - 🚧
run_tune.py
(하이퍼파라미터 튜닝) 🚧
객체지향, 함수지향 프로그래밍을 적재적소에 활용하여 cleanrnns
패키지를 정리해뒀어요 🧹
예를 들면?
datamodules.py
(객체지향 - 학습에 사용할 데이터셋을pl.LightningDataModule
객체로 추상화)datasets.py
(객체지향 - 풀고자하는 문제에 따른 데이터의 형식을torch.utils.data.Dataset
객체로 추상화)fetchers.py
(함수지향 - 데이터를 로드 및 다운로드하는 로직을 함수로 정의)models.py
(객체지향 - 풀고자하는 문제의 형식을pl.LightningModule
객체로 추상화)paths.py
(로컬 경로 정의)pipelines.py
(객체지향 - 예측에 필요한 로직을 하나의 객체로 추상화)preprocess.py
(함수지향 - 데이터 전처리에 필요한 로직을 함수로 정의)rnns.py
(객체지향 - 각 RNN 모델을torch.nn.Module
로 추상화)tensors.py
(함수지향 - 데이터셋 ->torch.Tensor
변환에 필요한 로직을 함수로 정의)
데이터 구축부터 모델 평가까지 진행해볼 수 있는 Colab 노트북을 만들어뒀어요. 실행을 해보면서 궁금한 점이 있다면 이슈를 남겨주세요 😊
git clone https://github.com/eubinecto/the-clean-rnns.git # 프로젝트 클론
cd the-clean-rnns # 루트 디렉토리 설정
pip3 install -r requirements.txt # 의존 라이브러리 설치
wandb login # Weights & Biases 계정 로그인 (회원가입 필요)
python3 run_build_nsmc.py # Naver Sentiment Movie Corpus
fetch_nsmc
로 구축한 데이터를 간편하게 확인해볼 수 있습니다. Korpora에서 제공하는 nsmc는 기본적으로 validation셋이 없지만,
전처리를 통해 구축합니다.
validation셋의 비율과, 랜덤 스플릿 시드는 config.yaml
에서 설정가능합니다.
from cleanrnns.fetchers import fetch_nsmc
train, val, test = fetch_nsmc()
for row in train.data[:10]:
print(row[0], row[1])
python3 run_build_tokenizer.py # BPE
토크나이저도 fetch_tokenizer
로 간편하게 확인해볼 수 있습니다. 본 프로젝트에서는 Byte Pair Encoding 알고리즘으로 구축하며, 어휘의 크기와 스페셜토큰은 config.yaml
에서 설정 가능합니다.
from cleanrnns.fetchers import fetch_tokenizer
tokenizer = fetch_tokenizer()
# 토크나이징
encoding = tokenizer.encode("이 영화 진짜 재미있다")
print(encoding.ids)
print([tokenizer.id_to_token(token_id) for token_id in encoding.ids])
# 스페셜 토큰
print(tokenizer.pad_token)
print(tokenizer.token_to_id(tokenizer.pad_token))
print(tokenizer.unk_token)
print(tokenizer.token_to_id(tokenizer.unk_token))
print(tokenizer.bos_token)
print(tokenizer.token_to_id(tokenizer.bos_token))
print(tokenizer.eos_token)
print(tokenizer.token_to_id(tokenizer.eos_token))
#어휘의 크기
print(tokenizer.get_vocab_size())
데이터셋과 토크나이저를 구축한뒤에는, 텍스트 데이터를 학습에 사용하기 위해 torch.Tensor
객체로 변환해야합니다. 이를 위한 로직은
NSNC
클래스 (LightningDataModule
)에 담겨 있으며, 모델학습을 진행할 때 내부적으로 사용됩니다. 물론 아래와 같이 텐서변환과정을
간단하게 확인해볼수는 있습니다.
import os
from cleanrnns.fetchers import fetch_config
from cleanrnns.datamodules import NSMC
config = fetch_config()['rnn_for_classification']
config['num_workers'] = os.cpu_count()
tokenizer = fetch_tokenizer()
datamodule = NSMC(config, tokenizer)
datamodule.prepare_data() # wandb로부터 구축된 텍스트데이터 다운로드
datamodule.setup() # 텐서로 변환
print("--- A batch from the training set ---")
for batch in datamodule.train_dataloader():
x, y = batch
print(x) # (N, L)
print(x.shape)
print(y) # (N,)
print(y.shape)
break
print("--- A batch from the validation set ---")
for batch in datamodule.val_dataloader():
x, y = batch
print(x) # (N, L)
print(x.shape)
print(y) # (N,)
print(y.shape)
break
python3 run_train.py rnn_for_classification
python3 run_train.py lstm_for_classification
python3 run_train.py bilsm_for_classification
RNN, LSTM, BiLSTM을 구축한 데이터에 학습시킵니다. hidden_size
, max_epochs
, 등의 하이퍼파라미터는 config.yaml
에서 설정가능합니다.
python3 run_test.py rnn_for_classification
python3 run_test.py lstm_for_classification
python3 run_test.py bilstm_for_classification
RNN, LSTM, BiLSTM의 성능을 구축한 테스트셋으로 측정합니다.
streamlit run run_deploy.py
웹에 배포를 원하신다면 Streamlit Cloud 사용을 추천!
-
ClassificationWithAttentionBase
->RNNForClassificationWithAttention
,LSTMForClassificationWithAttention
,BiLSTMForClassificationWithAttention
- seq2seq 지원
- ner 지원