Extend config with task specific configs. #3433

Conversation

@patrickvonplaten (Contributor) commented Mar 25, 2020:

As discussed and proposed by @thomwolf in PR #3413, this PR is another step towards a combined tokenizer/model config. It extends the standard config with the following parameters:

{
    ...,
    "prefix": "",                  # generic generation hyperparameter
    "max_length": 100,
    "length_penalty": 1.0,
    "task_specific_params": {
        "summarization": {         # task id (e.g. name of the pipeline?)
            "max_length": 140,
            "length_penalty": 2.0
        },
        "translation_en_to_de": {
            "prefix": "translate English to German: ",
            "max_length": 160,
            "length_penalty": 3.0
        }
    }
}

In terms of hierarchy for a task-specific generation, a parameter would be resolved as follows (a minimal sketch follows the list):

  1. Is the parameter provided as an argument to the generate method? If yes, use it; if not, go to 2.
  2. Is the parameter provided in the task_specific_params dict? If yes, use it; if not, go to 3.
  3. Is the parameter provided in the default config dict? If yes, use it; if not, go to 4.
  4. Is the parameter hard-coded in the model's config file? If yes, use it; if not, fall back to the defaults of PretrainedConfig.
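
A minimal sketch of that lookup order (hypothetical helper, not the actual implementation in this PR):

def resolve_generation_param(name, generate_kwargs, config, task=None):
    """Resolve one generation parameter following the hierarchy above (sketch)."""
    # 1. explicit argument passed to generate()
    if generate_kwargs.get(name) is not None:
        return generate_kwargs[name]
    # 2. task-specific override from config.task_specific_params
    task_params = getattr(config, "task_specific_params", None) or {}
    if task is not None and name in task_params.get(task, {}):
        return task_params[task][name]
    # 3./4. value from the (possibly hard-coded) model config, which itself
    # falls back to the PretrainedConfig defaults it was initialized with
    return getattr(config, name)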

These were our arguments in favor of this:

  • This removes a lot of hard-coded parameters in pipelines and examples
  • Another step towards a combined tokenizer / model config
  • A lot of weird if-else statements can be avoided ("If task is en-de translation, then do X" won't be necessary, as the en-de-specific parameters will override the default ones)

TODO

If you guys are fine with this structure:

@patrickvonplaten force-pushed the add_task_specific_params_to_config branch from f234b1e to 128d7e7 on March 25, 2020 at 15:01

@LysandreJik (Member) left a comment:
I think this makes sense, given where we've been going up to now.

I would like to understand what our philosophy is regarding the growing size of the configuration files; for example, the bert-base-cased configuration on S3 looks like this:

{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 28996
}

(which is readable imo) and once it's saved it now looks like this:

{
  "_num_labels": 2,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "min_length": 0,
  "model_type": "bert",
  "no_repeat_ngram_size": 0,
  "num_attention_heads": 12,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pad_token_id": 0,
  "pruned_heads": {},
  "repetition_penalty": 1.0,
  "temperature": 1.0,
  "top_k": 50,
  "top_p": 1.0,
  "torchscript": false,
  "type_vocab_size": 2,
  "use_bfloat16": false,
  "vocab_size": 28996
}

(which is less readable). Are we planning to keep them growing as the tokenizer and model configurations are merged? I feel like adding all those attributes to the configuration saves an "experiment" more than a "model". Is this something we're aiming for?

@LysandreJik (Member) left a comment:

Other than that question, LGTM!

@patrickvonplaten (Contributor, Author):

> (quoting @LysandreJik's comment above in full)

Might it be possible to only save parameters that are different from the default config of the corresponding model? This would keep it readable.

@@ -94,6 +91,16 @@ def __init__(self, **kwargs):
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())

# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.prefix = kwargs.pop("prefix", "")
Member:
should it be None instead of ""?

@patrickvonplaten (Contributor, Author):
I'd prefer "" here because then we can save some if statements and just write:

[config.prefix + text for text in texts]

in pipelines.py for example.

Member:

But then how are you going to avoid serializing it when calling save_pretrained?

We could do a getter for self.prefix or "" if it's easier syntax-wise
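
A minimal sketch of that getter idea (illustrative class and attribute names, not the code in this PR):

class ConfigWithPrefix:
    def __init__(self, **kwargs):
        # store None so that a "non-default-only" save can skip the attribute
        self._prefix = kwargs.pop("prefix", None)

    @property
    def prefix(self):
        # callers such as pipelines always get a string back
        return self._prefix or ""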

@patrickvonplaten (Contributor, Author):

Agree, will change that!

@patrickvonplaten (Contributor, Author) commented Mar 25, 2020:

My idea was to only serialize attributes that are different from the default config attributes. So if BartConfig.from_pretrained('bart-base-uncased') is created, we compare it to BartConfig() before serializing and only save params that differ from BartConfig(). It might be easier to just save all non-None parameters, but then we would additionally save all parameters that are not None by default.
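
A minimal sketch of that compare-to-defaults idea (hypothetical helper; BartConfig and the model id below are only examples):

def to_diff_dict(config, default_config):
    """Keep only the attributes whose values differ from the default config."""
    default_dict = default_config.to_dict()
    return {k: v for k, v in config.to_dict().items() if default_dict.get(k) != v}

# usage sketch (model id is just an example):
# import json
# from transformers import BartConfig
# config = BartConfig.from_pretrained("bart-large")
# print(json.dumps(to_diff_dict(config, BartConfig()), indent=2))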

@julien-c (Member):

LGTM and I agree with what @LysandreJik and you just said above. Serialized config.json should be more minimal.

For instance I've always disliked the id2label and label2id being serialized even for models that don't have a classification head.

@patrickvonplaten (Contributor, Author):

After this is merged I can open a new PR that serializes only the non-default values.

@thomwolf (Member) left a comment:

Quick tweak and ok to merge

self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)
Member:
Should it be kwargs.pop("task_specific_params", {}) so we can directly test its keys?

@patrickvonplaten (Contributor, Author) commented Mar 25, 2020:

I guess it's the same pro/con discussion as the one above about kwargs.pop("prefix", "") vs. kwargs.pop("prefix", None). If the default is not None, it would currently be serialized and saved in all configs that don't have task_specific_params, or am I missing something?
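
A small sketch of the trade-off being discussed (illustrative only): keep None as the stored default so it can be dropped when serializing non-default values, and fall back to an empty dict at the call site:

kwargs = {"vocab_size": 50265}  # e.g. a config dict without task_specific_params

# None stays the default, so a non-default-only save can drop the attribute entirely
task_specific_params = kwargs.pop("task_specific_params", None)

# call sites fall back to {} before testing keys
summarization_params = (task_specific_params or {}).get("summarization", {})
max_length = summarization_params.get("max_length", 20)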

@thomwolf (Member):

I agree with what @LysandreJik and @julien-c said about serializing only non-default values, by the way.
