diff --git a/src/gluonnlp/models/roberta.py b/src/gluonnlp/models/roberta.py index 5a72e13dd6..a7108f57b7 100644 --- a/src/gluonnlp/models/roberta.py +++ b/src/gluonnlp/models/roberta.py @@ -224,6 +224,16 @@ def __init__(self, ) self.encoder.hybridize() + if self.use_pooler: + # Construct pooler + self.pooler = nn.Dense(units=self.units, + in_units=self.units, + flatten=False, + activation=self.pooler_activation, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + prefix='pooler_') + if self.use_mlm: embed_weight = None if untie_weight else \ self.tokens_embed.collect_params('.*weight') @@ -293,7 +303,7 @@ def apply_pooling(self, sequence): Shape (batch_size, units) """ outputs = sequence[:, 0, :] - return outputs + return self.pooler(outputs) @staticmethod def get_cfg(key=None):