Commit

[API] use softmax with length, and interleaved matmul for BERT (dmlc#1091)

* use softmax with length, and interleaved matmul

* push backward compatibility fix

* fix failing unittests for output_all_encodings, and valid-len=None

* fix lint

* Update bert.py

* amp patch

* Update MXNet 1.6 pre-release version tested on CI

* Update bert.py

Co-authored-by: Leonard Lausen <leonard@lausen.nl>
2 people authored and szhengac committed Jan 24, 2020
1 parent 5fe8d7b commit e88d55e
Showing 4 changed files with 251 additions and 60 deletions.
2 changes: 1 addition & 1 deletion env/cpu/py3.yml
@@ -32,7 +32,7 @@ dependencies:
  - flaky==3.6.1
  - flake8==3.7.9
  - mock<3
- - https://lllausen-data.s3.amazonaws.com/mxnet_cu100-1.6.0b20191231-py2.py3-none-manylinux1_x86_64.whl
+ - https://lausen-public.s3.us-west-2.amazonaws.com/mxnet_cu100-1.6.0b20200123-py2.py3-none-manylinux1_x86_64.whl
  - scipy==1.3.2
  - regex==2019.11.1
  - nltk==3.4.5
2 changes: 1 addition & 1 deletion env/gpu/py3.yml
@@ -32,7 +32,7 @@ dependencies:
  - flaky==3.6.1
  - flake8==3.7.9
  - mock<3
- - https://lllausen-data.s3.amazonaws.com/mxnet_cu100-1.6.0b20191231-py2.py3-none-manylinux1_x86_64.whl
+ - https://lausen-public.s3.us-west-2.amazonaws.com/mxnet_cu100-1.6.0b20200123-py2.py3-none-manylinux1_x86_64.whl
  - scipy==1.3.2
  - regex==2019.11.1
  - nltk==3.4.5
8 changes: 8 additions & 0 deletions scripts/bert/finetune_classifier.py
@@ -182,8 +182,16 @@

args = parser.parse_args()


# patch AMP due to issue: https://github.com/apache/incubator-mxnet/issues/17409
ops = ['_contrib_interleaved_matmul_encdec_qk', '_contrib_interleaved_matmul_encdec_valatt',
       '_contrib_interleaved_matmul_selfatt_qk', '_contrib_interleaved_matmul_selfatt_valatt']
amp.lists.symbol.WIDEST_TYPE_CASTS.extend(ops)
# end of the patch

log = logging.getLogger()
log.setLevel(logging.INFO)

logging.captureWarnings(True)
fh = logging.FileHandler('log_{0}.txt'.format(args.task_name))
formatter = logging.Formatter(fmt='%(levelname)s:%(name)s:%(asctime)s %(message)s',
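
Note on the AMP patch in this hunk: it calls `extend` on `amp.lists.symbol.WIDEST_TYPE_CASTS` unconditionally, so running it twice in one process, or against a build that already lists these operators, would append duplicate entries. A minimal sketch of an idempotent variant is shown here. The helper `extend_unique` and the stand-in list `widest_type_casts` are hypothetical names used only for illustration; they are not part of this commit or of MXNet, and the stand-in replaces the real AMP list so the sketch runs without MXNet installed.

```python
def extend_unique(target, ops):
    """Append each op in `ops` to `target` unless it is already present.

    Returns the list of ops that were actually appended.
    """
    added = []
    for op in ops:
        if op not in target:
            target.append(op)
            added.append(op)
    return added


# Operator names copied from the patch in this commit.
INTERLEAVED_MATMUL_OPS = [
    '_contrib_interleaved_matmul_encdec_qk',
    '_contrib_interleaved_matmul_encdec_valatt',
    '_contrib_interleaved_matmul_selfatt_qk',
    '_contrib_interleaved_matmul_selfatt_valatt',
]

# Hypothetical stand-in for amp.lists.symbol.WIDEST_TYPE_CASTS.
widest_type_casts = ['some_existing_op']

first = extend_unique(widest_type_casts, INTERLEAVED_MATMUL_OPS)
second = extend_unique(widest_type_casts, INTERLEAVED_MATMUL_OPS)

print(first)   # the four operator names, added on the first call
print(second)  # empty list: the second call adds nothing
```

In the script itself, the stand-in list would presumably be replaced by `amp.lists.symbol.WIDEST_TYPE_CASTS`, assuming that object behaves like a plain Python list, as the `extend` call in the patch suggests.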