Separate out a new CNNConfig object instead of a SimpleNamespace, then make the load & save code convert all its pieces to types that work with weights_only=True. This will allow the CNNClassifier objects used for sentiment to function with weights_only=True.

Also convert the constituency classifier to use torch.load(weights_only=True), although there are no distributed constituency classifiers yet.
AngledLuffa committed Oct 28, 2024
1 parent f22be05 commit 67871f4
Showing 6 changed files with 174 additions and 78 deletions.
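To make the commit's motivation concrete, here is a minimal sketch (not Stanza code) of why a pickled Enum or SimpleNamespace blocks torch.load(weights_only=True) while a dict of primitive types loads cleanly; the checkpoint path and toy enum are made up for illustration, and the exact error message depends on the torch version.

import enum
import torch

class ModelType(enum.Enum):
    CNN = 1

# An Enum (or SimpleNamespace) inside the checkpoint is rejected by the
# weights_only unpickler, which only allows tensors and primitive containers.
torch.save({'config': {'model_type': ModelType.CNN}}, 'demo_checkpoint.pt')
try:
    torch.load('demo_checkpoint.pt', weights_only=True)
except Exception as e:
    print("refused to unpickle:", e)

# Storing only primitive types (the enum's name as a str) keeps the
# checkpoint loadable with weights_only=True.
torch.save({'config': {'model_type': ModelType.CNN.name}}, 'demo_checkpoint.pt')
checkpoint = torch.load('demo_checkpoint.pt', weights_only=True)
restored_type = ModelType[checkpoint['config']['model_type']]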
77 changes: 41 additions & 36 deletions stanza/models/classifiers/cnn_classifier.py
@@ -1,9 +1,9 @@
import dataclasses
import logging
import math
import os
import random
import re
from types import SimpleNamespace

import numpy as np
import torch
@@ -12,6 +12,7 @@

import stanza.models.classifiers.data as data
from stanza.models.classifiers.base_classifier import BaseClassifier
from stanza.models.classifiers.config import CNNConfig
from stanza.models.classifiers.data import SentimentDatum
from stanza.models.classifiers.utils import ExtraVectors, ModelType, build_output_layers
from stanza.models.common.bert_embedding import extract_bert_embeddings
@@ -71,47 +72,46 @@ def __init__(self, pretrain, extra_vocab, labels,
"""
super(CNNClassifier, self).__init__()
self.labels = labels
# existing models don't have the bert_finetune or use_peft arguments
bert_finetune = getattr(args, "bert_finetune", False)
use_peft = getattr(args, "use_peft", False)
bert_finetune = args.bert_finetune
use_peft = args.use_peft
force_bert_saved = force_bert_saved or bert_finetune
logger.debug("bert_finetune %s / force_bert_saved %s", bert_finetune, force_bert_saved)

# this may change when loaded in a new Pipeline, so it's not part of the config
self.peft_name = peft_name

# we build a separate config out of the args so that we can easily save it in torch
self.config = SimpleNamespace(filter_channels = args.filter_channels,
filter_sizes = args.filter_sizes,
fc_shapes = args.fc_shapes,
dropout = args.dropout,
num_classes = len(labels),
wordvec_type = args.wordvec_type,
extra_wordvec_method = args.extra_wordvec_method,
extra_wordvec_dim = args.extra_wordvec_dim,
extra_wordvec_max_norm = args.extra_wordvec_max_norm,
char_lowercase = args.char_lowercase,
charlm_projection = args.charlm_projection,
has_charlm_forward = charmodel_forward is not None,
has_charlm_backward = charmodel_backward is not None,
use_elmo = args.use_elmo,
elmo_projection = args.elmo_projection,
bert_model = args.bert_model,
bert_finetune = bert_finetune,
bert_hidden_layers = getattr(args, 'bert_hidden_layers', None),
force_bert_saved = force_bert_saved,

use_peft = use_peft,
lora_rank = getattr(args, 'lora_rank', None),
lora_alpha = getattr(args, 'lora_alpha', None),
lora_dropout = getattr(args, 'lora_dropout', None),
lora_modules_to_save = getattr(args, 'lora_modules_to_save', None),
lora_target_modules = getattr(args, 'lora_target_modules', None),

bilstm = args.bilstm,
bilstm_hidden_dim = args.bilstm_hidden_dim,
maxpool_width = args.maxpool_width,
model_type = ModelType.CNN)
self.config = CNNConfig(filter_channels = args.filter_channels,
filter_sizes = args.filter_sizes,
fc_shapes = args.fc_shapes,
dropout = args.dropout,
num_classes = len(labels),
wordvec_type = args.wordvec_type,
extra_wordvec_method = args.extra_wordvec_method,
extra_wordvec_dim = args.extra_wordvec_dim,
extra_wordvec_max_norm = args.extra_wordvec_max_norm,
char_lowercase = args.char_lowercase,
charlm_projection = args.charlm_projection,
has_charlm_forward = charmodel_forward is not None,
has_charlm_backward = charmodel_backward is not None,
use_elmo = args.use_elmo,
elmo_projection = args.elmo_projection,
bert_model = args.bert_model,
bert_finetune = bert_finetune,
bert_hidden_layers = args.bert_hidden_layers,
force_bert_saved = force_bert_saved,

use_peft = use_peft,
lora_rank = args.lora_rank,
lora_alpha = args.lora_alpha,
lora_dropout = args.lora_dropout,
lora_modules_to_save = args.lora_modules_to_save,
lora_target_modules = args.lora_target_modules,

bilstm = args.bilstm,
bilstm_hidden_dim = args.bilstm_hidden_dim,
maxpool_width = args.maxpool_width,
model_type = ModelType.CNN)

self.char_lowercase = args.char_lowercase

@@ -521,9 +521,14 @@ def get_params(self, skip_modules=True):
for k in skipped:
del model_state[k]

config = dataclasses.asdict(self.config)
config['wordvec_type'] = config['wordvec_type'].name
config['extra_wordvec_method'] = config['extra_wordvec_method'].name
config['model_type'] = config['model_type'].name

params = {
'model': model_state,
'config': self.config,
'config': config,
'labels': self.labels,
'extra_vocab': self.extra_vocab,
}
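As a quick illustration of the conversion that get_params now performs and the loading code undoes, here is a round-trip sketch using a trimmed stand-in dataclass rather than the full CNNConfig; the field names and values are made up.

import dataclasses
import enum

class ModelType(enum.Enum):
    CNN = 1
    CONSTITUENCY = 2

@dataclasses.dataclass
class TinyConfig:
    dropout: float
    model_type: ModelType

config = TinyConfig(dropout=0.5, model_type=ModelType.CNN)

# save side: dataclass -> plain dict, enum member -> its name (a str)
state = dataclasses.asdict(config)
state['model_type'] = state['model_type'].name        # 'CNN'

# load side: name -> enum member, plain dict -> dataclass
state['model_type'] = ModelType[state['model_type']]
restored = TinyConfig(**state)
assert restored == config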
55 changes: 55 additions & 0 deletions stanza/models/classifiers/config.py
@@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import List

# TODO: perhaps put the enums in this file
from stanza.models.classifiers.utils import WVType, ExtraVectors, ModelType

@dataclass
class CNNConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods
filter_channels: int | tuple
filter_sizes: tuple
fc_shapes: tuple
dropout: float
num_classes: int
wordvec_type: WVType
extra_wordvec_method: ExtraVectors
extra_wordvec_dim: int
extra_wordvec_max_norm: float
char_lowercase: bool
charlm_projection: int
has_charlm_forward: bool
has_charlm_backward: bool

use_elmo: bool
elmo_projection: int

bert_model: str
bert_finetune: bool
bert_hidden_layers: int
force_bert_saved: bool

use_peft: bool
lora_rank: int
lora_alpha: float
lora_dropout: float
lora_modules_to_save: List
lora_target_modules: List

bilstm: bool
bilstm_hidden_dim: int
maxpool_width: int
model_type: ModelType

@dataclass
class ConstituencyConfig: # pylint: disable=too-many-instance-attributes, too-few-public-methods
fc_shapes: tuple
dropout: float
num_classes: int

constituency_backprop: bool
constituency_batch_norm: bool
constituency_node_attn: bool
constituency_top_layer: bool
constituency_all_words: bool

model_type: ModelType
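For reference, a hypothetical construction of the smaller ConstituencyConfig; in Stanza these fields are filled from the command-line args, and the values below are illustrative only.

from stanza.models.classifiers.config import ConstituencyConfig
from stanza.models.classifiers.utils import ModelType

config = ConstituencyConfig(fc_shapes=(100, 50),
                            dropout=0.3,
                            num_classes=5,
                            constituency_backprop=False,
                            constituency_batch_norm=False,
                            constituency_node_attn=True,
                            constituency_top_layer=False,
                            constituency_all_words=True,
                            model_type=ModelType.CONSTITUENCY)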
25 changes: 15 additions & 10 deletions stanza/models/classifiers/constituency_classifier.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
A classifier that uses a constituency parser for the base embeddings
"""

import dataclasses
import logging
from types import SimpleNamespace

@@ -10,6 +11,7 @@
import torch.nn.functional as F

from stanza.models.classifiers.base_classifier import BaseClassifier
from stanza.models.classifiers.config import ConstituencyConfig
from stanza.models.classifiers.data import SentimentDatum
from stanza.models.classifiers.utils import ModelType, build_output_layers

@@ -23,15 +25,15 @@ def __init__(self, tree_embedding, labels, args):
super(ConstituencyClassifier, self).__init__()
self.labels = labels
# we build a separate config out of the args so that we can easily save it in torch
self.config = SimpleNamespace(fc_shapes = args.fc_shapes,
dropout = args.dropout,
num_classes = len(labels),
constituency_backprop = args.constituency_backprop,
constituency_batch_norm = args.constituency_batch_norm,
constituency_node_attn = args.constituency_node_attn,
constituency_top_layer = args.constituency_top_layer,
constituency_all_words = args.constituency_all_words,
model_type = ModelType.CONSTITUENCY)
self.config = ConstituencyConfig(fc_shapes = args.fc_shapes,
dropout = args.dropout,
num_classes = len(labels),
constituency_backprop = args.constituency_backprop,
constituency_batch_norm = args.constituency_batch_norm,
constituency_node_attn = args.constituency_node_attn,
constituency_top_layer = args.constituency_top_layer,
constituency_all_words = args.constituency_all_words,
model_type = ModelType.CONSTITUENCY)

self.tree_embedding = tree_embedding

@@ -79,10 +81,13 @@ def get_params(self, skip_modules=True):

tree_embedding = self.tree_embedding.get_params(skip_modules)

config = dataclasses.asdict(self.config)
config['model_type'] = config['model_type'].name

params = {
'model': model_state,
'tree_embedding': tree_embedding,
'config': self.config,
'config': config,
'labels': self.labels,
}
return params
56 changes: 43 additions & 13 deletions stanza/models/classifiers/trainer.py
@@ -8,17 +8,22 @@
import os
import torch
import torch.optim as optim
from types import SimpleNamespace

import stanza.models.classifiers.data as data
import stanza.models.classifiers.cnn_classifier as cnn_classifier
import stanza.models.classifiers.constituency_classifier as constituency_classifier
from stanza.models.classifiers.utils import ModelType
from stanza.models.classifiers.config import CNNConfig, ConstituencyConfig
from stanza.models.classifiers.utils import ModelType, WVType, ExtraVectors
from stanza.models.common.foundation_cache import load_bert, load_bert_with_peft, load_charlm, load_pretrain
from stanza.models.common.peft_config import build_peft_wrapper, load_peft_wrapper
from stanza.models.common.pretrain import Pretrain
from stanza.models.common.utils import get_split_optimizer
from stanza.models.constituency.tree_embedding import TreeEmbedding

from pickle import UnpicklingError
import warnings

logger = logging.getLogger('stanza')

class Trainer:
@@ -69,9 +74,13 @@ def load(filename, args, foundation_cache=None, load_optimizer=False):
else:
raise FileNotFoundError("Cannot find model in {} or in {}".format(filename, os.path.join(args.save_dir, filename)))
try:
# TODO: switch to weights_only=True
# need to convert enums to int first
checkpoint = torch.load(filename, lambda storage, loc: storage)
# TODO: can remove the try/except once the new version is out
#checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
try:
checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
except UnpicklingError as e:
checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=False)
warnings.warn("The saved classifier has an old format using SimpleNamespace and/or Enum instead of a dict to store config. This version of Stanza can support reading both the new and the old formats. Future versions will only allow loading with weights_only=True. Please resave the pretrained classifier using this version ASAP.")
except BaseException:
logger.exception("Cannot load model from {}".format(filename))
raise
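The same fallback pattern, rewritten as a standalone helper sketch; it assumes an arbitrary checkpoint path and is not the exact Trainer code.

import warnings
from pickle import UnpicklingError

import torch

def load_checkpoint(filename):
    try:
        # new-format checkpoints contain only dicts, strings, numbers, and tensors
        return torch.load(filename, map_location='cpu', weights_only=True)
    except UnpicklingError:
        # old-format checkpoints pickled SimpleNamespace / Enum objects
        warnings.warn("Old checkpoint format; falling back to weights_only=False. "
                      "Please resave with a current Stanza version.")
        return torch.load(filename, map_location='cpu', weights_only=False)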
@@ -91,22 +100,42 @@ def load(filename, args, foundation_cache=None, load_optimizer=False):
}
else:
model_params = checkpoint['params']
model_type = model_params['config'].model_type
# TODO: this can be removed once v1.10.0 is out
if isinstance(model_params['config'], SimpleNamespace):
model_params['config'] = vars(model_params['config'])
# TODO: these isinstance can go away after 1.10.0
model_type = model_params['config']['model_type']
if isinstance(model_type, str):
model_type = ModelType[model_type]
model_params['config']['model_type'] = model_type

if model_type == ModelType.CNN:
# TODO: these updates are only necessary during the
# transition to the @dataclass version of the config
# Once those are all saved, it is no longer necessary
# to patch existing models (since they will all be patched)
if 'has_charlm_forward' not in model_params['config']:
model_params['config']['has_charlm_forward'] = args.charlm_forward_file is not None
if 'has_charlm_backward' not in model_params['config']:
model_params['config']['has_charlm_backward'] = args.charlm_backward_file is not None
for argname in ['bert_hidden_layers', 'bert_finetune', 'force_bert_saved', 'use_peft',
'lora_rank', 'lora_alpha', 'lora_dropout', 'lora_modules_to_save', 'lora_target_modules']:
model_params['config'][argname] = model_params['config'].get(argname, None)
# TODO: these isinstance can go away after 1.10.0
if isinstance(model_params['config']['wordvec_type'], str):
model_params['config']['wordvec_type'] = WVType[model_params['config']['wordvec_type']]
if isinstance(model_params['config']['extra_wordvec_method'], str):
model_params['config']['extra_wordvec_method'] = ExtraVectors[model_params['config']['extra_wordvec_method']]
model_params['config'] = CNNConfig(**model_params['config'])

pretrain = Trainer.load_pretrain(args, foundation_cache)
elmo_model = utils.load_elmo(args.elmo_model) if args.use_elmo else None
# TODO: existing models don't have this attribute, so we
# use None as not having a setting. If the setting is
# False, though, we don't load the charlm
# We don't want to pass a charlm to a model which doesn't use one
has_charlm_forward = getattr(model_params['config'], 'has_charlm_forward', None)
if has_charlm_forward != False:

if model_params['config'].has_charlm_forward:
charmodel_forward = load_charlm(args.charlm_forward_file, foundation_cache)
else:
charmodel_forward = None
has_charlm_backward = getattr(model_params['config'], 'has_charlm_backward', None)
if has_charlm_backward != False:
if model_params['config'].has_charlm_backward:
charmodel_backward = load_charlm(args.charlm_backward_file, foundation_cache)
else:
charmodel_backward = None
@@ -148,6 +177,7 @@ def load(filename, args, foundation_cache=None, load_optimizer=False):
}
# TODO: integrate with peft for the constituency version
tree_embedding = TreeEmbedding.model_from_params(model_params['tree_embedding'], pretrain_args, foundation_cache)
model_params['config'] = ConstituencyConfig(**model_params['config'])
model = constituency_classifier.ConstituencyClassifier(tree_embedding=tree_embedding,
labels=model_params['labels'],
args=model_params['config'])
3 changes: 2 additions & 1 deletion stanza/pipeline/sentiment_processor.py
@@ -9,6 +9,7 @@
ClassifierProcessor and have "sentiment" be an option.
"""

import dataclasses
import torch

from types import SimpleNamespace
@@ -62,7 +63,7 @@ def _set_up_model(self, config, pipeline, device):
self._batch_size = config.get('batch_size', SentimentProcessor.DEFAULT_BATCH_SIZE)

def _set_up_final_config(self, config):
loaded_args = vars(self._model.config)
loaded_args = dataclasses.asdict(self._model.config)
loaded_args = {k: v for k, v in loaded_args.items() if not UDProcessor.filter_out_option(k)}
loaded_args.update(config)
self._config = loaded_args
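A toy comparison of the two calls, outside of Stanza: vars() reads an object's __dict__, which worked for the old SimpleNamespace config, while dataclasses.asdict() is the dataclass equivalent and also recurses into nested dataclasses.

import dataclasses
from types import SimpleNamespace

ns = SimpleNamespace(dropout=0.5)
print(vars(ns))                              # {'dropout': 0.5}

@dataclasses.dataclass
class Cfg:
    dropout: float

print(dataclasses.asdict(Cfg(dropout=0.5)))  # {'dropout': 0.5}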