-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
26 lines (19 loc) · 776 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""Download datasets and train transformer-based models."""
import argparse
import logging
import batching
import datasets
def main(argv=None):
    """The main entry point to transformers-sota.

    Parses the command-line options, configures the root logger, obtains
    the datasets from the dataset directory and produces batches from the
    ``'train'`` entry of the result.

    Args:
        argv: Optional list of command-line arguments, without the program
            name. When ``None`` (the default), argparse reads
            ``sys.argv[1:]``, which is what a plain ``main()`` call did
            before this parameter existed.

    Raises:
        SystemExit: With status 2 when the arguments cannot be parsed or
            when ``--log_level`` does not name a logging level.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset_directory',
        default='/tmp/translation_datasets/',
        help='The local directory where the datasets are stored.')
    parser.add_argument(
        '--log_level',
        default='WARNING',
        help='The desired logging level.')
    args = parser.parse_args(argv)

    # Resolve the level name case-insensitively. An unknown name used to
    # become None, which basicConfig() silently ignores, and a name that
    # matched a non-level attribute of the logging module failed later with
    # a raw traceback. Both cases are now reported as a usage error.
    level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(level, int):
        parser.error('invalid --log_level: %r' % args.log_level)
    logging.basicConfig(level=level)

    data = datasets.obtain(args.dataset_directory)
    # NOTE(review): the result is not consumed yet, and the meaning of 512
    # is defined by batching.produce_batches -- confirm before relying on it.
    batches = batching.produce_batches(data['train'], 512)
# Run the entry point only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()