Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Function optimization] Add test for autoconverter of Bert/GPT models #4537

Merged
merged 3 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion paddlenlp/transformers/bert/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,15 @@ def _get_name_mappings(cls, config: BertConfig) -> list[StateDictNameMapping]:

# downstream mappings
if "BertForQuestionAnswering" in config.architectures:
model_mappings.extend([["qa_outputs.weight", "classifier.weight"], ["qa_outputs.bias", "classifier.bias"]])
model_mappings.extend(
[["qa_outputs.weight", "classifier.weight", "transpose"], ["qa_outputs.bias", "classifier.bias"]]
)
if (
"BertForMultipleChoice" in config.architectures
or "BertForSequenceClassification" in config.architectures
or "BertForTokenClassification" in config.architectures
):
model_mappings.extend([["classifier.weight", "classifier.weight", "transpose"]])

mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
return mappings
Expand Down
6 changes: 5 additions & 1 deletion paddlenlp/transformers/gpt/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,15 @@ def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]:

model_mappings.extend(layer_mappings)

# downstream mappings
if "GPT2Model" not in config.architectures:
for mapping in model_mappings:
mapping[0] = "transformer." + mapping[0]
mapping[1] = "gpt." + mapping[1]

if "GPT2ForTokenClassification" in config.architectures:
model_mappings.extend([["classifier.weight", "classifier.weight", "transpose"]])
if "GPT2ForSequenceClassification" in config.architectures:
model_mappings.extend([["score.weight", "score.weight", "transpose"]])
if "GPT2LMHeadModel" in config.architectures:
model_mappings.append(["lm_head.weight", "lm_head.decoder_weight"])

Expand Down
69 changes: 68 additions & 1 deletion tests/transformers/bert/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from __future__ import annotations

import os
import random
import tempfile
import unittest
from typing import List

import numpy as np
import paddle
from parameterized import parameterized_class
from parameterized import parameterized, parameterized_class

from paddlenlp import __version__ as current_version
from paddlenlp.transformers import (
Expand Down Expand Up @@ -442,6 +443,18 @@ def test_params_compatibility_of_init_method(self):


class BertCompatibilityTest(unittest.TestCase):
test_model_id = "hf-internal-testing/tiny-random-BertModel"

@classmethod
@require_package("transformers", "torch")
def setUpClass(cls) -> None:
from transformers import BertModel

# when python application is done, `TemporaryDirectory` will be free
cls.torch_model_path = tempfile.TemporaryDirectory().name
model = BertModel.from_pretrained(cls.test_model_id)
model.save_pretrained(cls.torch_model_path)

def test_model_config_mapping(self):
config = BertConfig(num_labels=22, hidden_dropout_prob=0.99)
self.assertEqual(config.hidden_dropout_prob, 0.99)
Expand Down Expand Up @@ -661,6 +674,60 @@ def test_bert_converter_from_local_dir(self):
np.allclose(paddle_logit.detach().cpu().numpy(), torch_logit.detach().cpu().numpy(), rtol=1e-4)
)

@parameterized.expand(
[
("BertModel",),
# ("BertForMaskedLM",), TODO: need to tie weights
# ("BertForPretraining", "BertForPreTraining"), TODO: need to tie weights
("BertForMultipleChoice",),
("BertForQuestionAnswering",),
("BertForSequenceClassification",),
("BertForTokenClassification",),
]
)
@require_package("transformers", "torch")
def test_bert_classes_from_local_dir(self, class_name, pytorch_class_name: str | None = None):
pytorch_class_name = pytorch_class_name or class_name
with tempfile.TemporaryDirectory() as tempdir:

# 1. create commmon input
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the torch model
import torch
import transformers

torch_model_class = getattr(transformers, pytorch_class_name)
torch_model = torch_model_class.from_pretrained(self.torch_model_path)
torch_model.eval()

if "MultipleChoice" in class_name:
# construct input for MultipleChoice Model
torch_model.config.num_choices = random.randint(2, 10)
input_ids = (
paddle.to_tensor(input_ids)
.unsqueeze(1)
.expand([-1, torch_model.config.num_choices, -1])
.cpu()
.numpy()
)

torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 3. forward the paddle model
from paddlenlp import transformers

paddle_model_class = getattr(transformers, class_name)
paddle_model = paddle_model_class.from_pretrained(tempdir)
paddle_model.eval()

paddle_logit = paddle_model(paddle.to_tensor(input_ids), return_dict=False)[0]

self.assertTrue(
np.allclose(paddle_logit.detach().cpu().numpy(), torch_logit.detach().cpu().numpy(), rtol=1e-4)
)


class BertModelIntegrationTest(ModelTesterPretrainedMixin, unittest.TestCase):
base_model_class = BertModel
Expand Down
103 changes: 36 additions & 67 deletions tests/transformers/gpt/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np
import paddle
from parameterized import parameterized_class
from parameterized import parameterized, parameterized_class

from paddlenlp.transformers import (
GPTConfig,
Expand Down Expand Up @@ -555,34 +555,17 @@ def test_model_from_pretrained(self):


class GPTCompatibilityTest(unittest.TestCase):
@require_package("transformers", "torch")
def test_gpt_converter(self):
with tempfile.TemporaryDirectory() as tempdir:

# 1. create commmon input
input_ids = np.array([[i for i in range(10)]])

# 2. forward the paddle model
from paddlenlp.transformers import GPTModel
test_model_id = "hf-internal-testing/tiny-random-GPT2Model"

paddle_model = GPTModel.from_pretrained(
"hf-internal-testing/tiny-random-GPT2Model", from_hf_hub=True, cache_dir=tempdir
)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

# 3. forward the torch model
import torch
from transformers import GPT2Model
@classmethod
@require_package("transformers", "torch")
def setUpClass(cls) -> None:
from transformers import GPT2Model

torch_model = GPT2Model.from_pretrained("hf-internal-testing/tiny-random-GPT2Model", cache_dir=tempdir)
torch_model.eval()
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0][0]
self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().numpy()[:4, :4], torch_logit.detach().cpu().numpy()[:4, :4], rtol=1e-4
)
)
# when python application is done, `TemporaryDirectory` will be free
cls.torch_model_path = tempfile.TemporaryDirectory().name
model = GPT2Model.from_pretrained(cls.test_model_id)
model.save_pretrained(cls.torch_model_path)

@require_package("transformers", "torch")
def test_gpt_converter_from_local_dir_with_enable_torch(self):
Expand All @@ -591,7 +574,7 @@ def test_gpt_converter_from_local_dir_with_enable_torch(self):
# 2. forward the torch model
from transformers import GPT2Model

torch_model = GPT2Model.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
torch_model = GPT2Model.from_pretrained(self.test_model_id)
torch_model.save_pretrained(tempdir)

# 2. forward the paddle model
Expand All @@ -604,61 +587,47 @@ def test_gpt_converter_from_local_dir_with_enable_torch(self):
self.assertIn("conversion is been disabled" in str(error.exception))
model_utils.ENABLE_TORCH_CHECKPOINT = True

@parameterized.expand(
[
("GPTModel", "GPT2Model"),
("GPTForSequenceClassification", "GPT2ForSequenceClassification"),
("GPTForTokenClassification", "GPT2ForTokenClassification"),
("GPTLMHeadModel", "GPT2LMHeadModel"),
]
)
@require_package("transformers", "torch")
def test_gpt_converter_from_local_dir(self):
def test_gpt_classes_from_local_dir(self, paddle_class_name, pytorch_class_name: str | None = None):
pytorch_class_name = pytorch_class_name or paddle_class_name
with tempfile.TemporaryDirectory() as tempdir:

# 1. create commmon input
input_ids = np.array([[i for i in range(10)]])
input_ids = np.random.randint(100, 200, [1, 20])

# 2. forward the torch model
# 2. forward the torch model
import torch
from transformers import GPT2Model
import transformers

torch_model = GPT2Model.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
torch_model_class = getattr(transformers, pytorch_class_name)
torch_model = torch_model_class.from_pretrained(self.torch_model_path)
torch_model.eval()
torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0][0]

# 2. forward the paddle model
from paddlenlp.transformers import GPTModel

paddle_model = GPTModel.from_pretrained(tempdir)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().numpy()[:4, :4], torch_logit.detach().cpu().numpy()[:4, :4], rtol=1e-4
)
)

@require_package("transformers", "torch")
def test_gpt_for_lm_head(self):
with tempfile.TemporaryDirectory() as tempdir:

# 1. create commmon input
input_ids = np.array([[i for i in range(10)]])

# 2. forward the torch model
import torch
from transformers import GPT2LMHeadModel

torch_model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-GPT2Model")
torch_model.eval()
torch_model.save_pretrained(tempdir)
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0][0]
torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]

# 2. forward the paddle model
from paddlenlp.transformers import GPTLMHeadModel
# 3. forward the paddle model
from paddlenlp import transformers

paddle_model = GPTLMHeadModel.from_pretrained(tempdir)
paddle_model_class = getattr(transformers, paddle_class_name)
paddle_model = paddle_model_class.from_pretrained(tempdir)
paddle_model.eval()
paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]

paddle_logit = paddle_model(paddle.to_tensor(input_ids), return_dict=False)[0]

self.assertTrue(
np.allclose(
paddle_logit.detach().cpu().numpy()[:4, :2], torch_logit.detach().cpu().numpy()[:4, :2], rtol=1e-4
paddle_logit.detach().cpu().numpy().reshape([-1])[:16],
torch_logit.detach().cpu().numpy().reshape([-1])[:16],
rtol=1e-4,
)
)

Expand Down