
Commit

Extend config with task specific configs. (#3433)
* add new default configs

* change prefix default to None
patrickvonplaten authored Mar 25, 2020
1 parent 83272a3 commit ffa17fe
Showing 3 changed files with 27 additions and 13 deletions.
src/transformers/configuration_utils.py: 13 changes (10 additions, 3 deletions)
@@ -78,9 +78,6 @@ def __init__(self, **kwargs):
         self.top_k = kwargs.pop("top_k", 50)
         self.top_p = kwargs.pop("top_p", 1.0)
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
-        self.bos_token_id = kwargs.pop("bos_token_id", None)
-        self.pad_token_id = kwargs.pop("pad_token_id", None)
-        self.eos_token_id = kwargs.pop("eos_token_id", None)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
@@ -94,6 +91,16 @@ def __init__(self, **kwargs):
         self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
         self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
 
+        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+        self.prefix = kwargs.pop("prefix", None)
+        self.bos_token_id = kwargs.pop("bos_token_id", None)
+        self.pad_token_id = kwargs.pop("pad_token_id", None)
+        self.eos_token_id = kwargs.pop("eos_token_id", None)
+        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+
+        # task specific arguments
+        self.task_specific_params = kwargs.pop("task_specific_params", None)
+
         # Additional attributes without default values
         for key, value in kwargs.items():
             try:
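Taken together, the new config attributes let a checkpoint carry its tokenizer defaults and per-task generation presets. The snippet below is a minimal sketch of how they might be populated, assuming the public PretrainedConfig class from this repository; every concrete value and preset name is illustrative, not a default shipped by this commit.

from transformers import PretrainedConfig

# Illustrative values only; nothing below is a shipped default.
config = PretrainedConfig(
    prefix="summarize: ",              # text that downstream tooling may prepend to inputs
    bos_token_id=0,
    pad_token_id=1,
    eos_token_id=2,
    decoder_start_token_id=2,          # consulted by generate() for encoder-decoder models
    task_specific_params={
        "summarization": {"num_beams": 4, "max_length": 142, "min_length": 56},
        "translation_en_to_de": {"num_beams": 4, "max_length": 300},
    },
)

# Downstream code can look up a preset and use it to override generation defaults.
summarization_params = (config.task_specific_params or {}).get("summarization", {})
print(summarization_params["num_beams"])  # 4

Because every new key defaults to None, older configs that do not define these attributes keep loading unchanged.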
src/transformers/modeling_tf_utils.py: 13 changes (8 additions, 5 deletions)
@@ -610,7 +610,9 @@ def generate(
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
-        decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+        decoder_start_token_id = (
+            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+        )
 
         if input_ids is not None:
             batch_size = shape_list(input_ids)[0]  # overriden by the input batch_size
@@ -635,9 +637,6 @@
         assert (eos_token_id is None) or (
             isinstance(eos_token_id, int) and (eos_token_id >= 0)
         ), "`eos_token_id` should be a positive integer."
-        assert (
-            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
-        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
         assert length_penalty > 0, "`length_penalty` should be strictely positive."
         assert (
             isinstance(num_return_sequences, int) and num_return_sequences > 0
@@ -708,8 +707,12 @@
         )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
 
         if self.config.is_encoder_decoder:
+            if decoder_start_token_id is None:
+                decoder_start_token_id = bos_token_id
 
-            assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
+            assert (
+                decoder_start_token_id is not None
+            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
 
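Both generate() implementations now resolve the decoder start token the same way: an explicit decoder_start_token_id argument wins, otherwise config.decoder_start_token_id is used, and for encoder-decoder models bos_token_id remains the final fallback. The helper below is a standalone sketch of that order, not library code; the function name and the config argument are hypothetical.

# Hypothetical helper mirroring the fallback order in the diff above.
# `config` is any object exposing decoder_start_token_id, not the real model class.
def resolve_decoder_start_token_id(arg_value, config, bos_token_id, is_encoder_decoder):
    token_id = arg_value if arg_value is not None else config.decoder_start_token_id
    if is_encoder_decoder and token_id is None:
        token_id = bos_token_id  # last resort, as in the encoder-decoder branch above
    if is_encoder_decoder:
        assert token_id is not None, (
            "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
        )
    return token_id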
src/transformers/modeling_utils.py: 14 changes (9 additions, 5 deletions)
@@ -809,7 +809,9 @@ def generate(
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
-        decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
+        decoder_start_token_id = (
+            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+        )
 
         if input_ids is not None:
             batch_size = input_ids.shape[0]  # overriden by the input batch_size
@@ -831,9 +833,6 @@
         assert pad_token_id is None or (
             isinstance(pad_token_id, int) and (pad_token_id >= 0)
         ), "`pad_token_id` should be a positive integer."
-        assert (
-            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
-        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
         assert (eos_token_id is None) or (
             isinstance(eos_token_id, int) and (eos_token_id >= 0)
         ), "`eos_token_id` should be a positive integer."
@@ -912,7 +911,12 @@
         )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
 
         if self.config.is_encoder_decoder:
-            assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
+            if decoder_start_token_id is None:
+                decoder_start_token_id = bos_token_id
+
+            assert (
+                decoder_start_token_id is not None
+            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
 
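At call time the two changes combine as follows: generate() picks up decoder_start_token_id from the config when the caller does not pass one, and task_specific_params gives callers a place to read per-task overrides from. The sketch below is a hedged usage example; the checkpoint name, the presence of a "summarization" preset, and generation support for that checkpoint at this commit are all assumptions.

from transformers import AutoModelWithLMHead, AutoTokenizer

# Assumed seq2seq checkpoint; any encoder-decoder model with a populated config would do.
model = AutoModelWithLMHead.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Read the task preset and prefix from the config, if the checkpoint provides them.
task_params = (model.config.task_specific_params or {}).get("summarization", {})
prefix = model.config.prefix or ""

text = "The tower is 324 metres tall, about the same height as an 81-storey building."
input_ids = tokenizer.encode(prefix + text, return_tensors="pt")

# decoder_start_token_id is resolved from the config when not passed explicitly,
# falling back to bos_token_id for encoder-decoder models.
output_ids = model.generate(input_ids, **task_params)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))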
