Akash/bugfix #55

Merged 12 commits on Nov 7, 2023
16 changes: 8 additions & 8 deletions docs/source/notes/combine_methods.md
@@ -18,15 +18,15 @@ The following describes each supported method and whether or not it requires both categorical and numerical features.
| gating_on_cat_and_num_feats_then_sum | Gated summation of transformer outputs, numerical feats, and categorical feats before final classifier layer(s). Inspired by [Integrating Multimodal Information in Large Pretrained Transformers](https://www.aclweb.org/anthology/2020.acl-main.214.pdf) which performs the mechanism for each token. | False
| weighted_feature_sum_on_transformer_cat_and_numerical_feats | Learnable weighted feature-wise sum of transformer outputs, numerical feats and categorical feats for each feature dimension before final classifier layer(s) | False

- This table shows the the equations involved with each method. First we define some notation
+ This table shows the equations involved with each method. First we define some notations:

- * ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bm%7D) denotes the combined multimodal features
- * ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bx%7D) denotes the output text features from the transformer
- * ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bc%7D) denotes the categorical features
- * ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bn%7D) denotes the numerical features
- * ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20h_%7B%5Cmathbf%7B%5CTheta%7D%7D) denotes a MLP parameterized by ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7B%5CTheta%7D)
- * ![equation](https://latex.codecogs.com/svg.latex?%5Cmathbf%7BW%7D) denotes a weight matrix
- * ![equation](https://latex.codecogs.com/svg.latex?b) denotes a scalar bias
+ * ![m](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bm%7D) denotes the combined multimodal features
+ * ![x](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bx%7D) denotes the output text features from the transformer
+ * ![c](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bc%7D) denotes the categorical features
+ * ![n](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bn%7D) denotes the numerical features
+ * ![h_theta](https://latex.codecogs.com/svg.latex?%5Cinline%20h_%7B%5Cmathbf%7B%5CTheta%7D%7D) denotes an MLP parameterized by ![theta](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7B%5CTheta%7D)
+ * ![W](https://latex.codecogs.com/svg.latex?%5Cmathbf%7BW%7D) denotes a weight matrix
+ * ![b](https://latex.codecogs.com/svg.latex?b) denotes a scalar bias

| Combine Feat Method | Equation |
|:--------------|:-------------------|
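The equation rows of this table are collapsed in the diff view. Purely as an illustration of the notation above (an assumed example, not the file's actual contents), a feature-wise weighted sum of text, categorical, and numerical features in this style would read:

```latex
% Assumed illustration of the notation above -- not copied from combine_methods.md.
\mathbf{m} = \mathbf{x}
  + \mathbf{w}_{c} \odot h_{\mathbf{\Theta}_{c}}(\mathbf{c})
  + \mathbf{w}_{n} \odot h_{\mathbf{\Theta}_{n}}(\mathbf{n})
```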
1 change: 0 additions & 1 deletion docs/source/notes/introduction.rst
@@ -32,7 +32,6 @@ Say for example we had categorical features of dim 9 and numerical features of dim 5
cat_feat_dim=9, # need to specify this
numerical_feat_dim=5, # need to specify this
num_labels=2, # need to specify this, assuming our task is binary classification
- use_num_bn=False,
)

bert_config.tabular_config = tabular_config
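With use_num_bn gone, batchnorm behavior now comes from the config flags this PR adds in tabular_config.py (see below). A minimal sketch of the updated snippet, with the import path taken from the file names in this diff and the combine method assumed:

```python
from multimodal_transformers.model.tabular_config import TabularConfig

tabular_config = TabularConfig(
    combine_feat_method="gating_on_cat_and_num_feats_then_sum",  # assumed choice
    cat_feat_dim=9,         # need to specify this
    numerical_feat_dim=5,   # need to specify this
    num_labels=2,           # assuming our task is binary classification
    numerical_bn=True,      # replaces the removed use_num_bn flag
    categorical_bn=True,    # new in this PR
)

bert_config.tabular_config = tabular_config  # bert_config as in the original snippet
```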
2 changes: 1 addition & 1 deletion main.py
@@ -200,7 +200,7 @@ def compute_metrics_fn(p: EvalPrediction):
)
if training_args.do_train:
trainer.train(
- model_path=model_args.model_name_or_path
+ resume_from_checkpoint=model_args.model_name_or_path
if os.path.isdir(model_args.model_name_or_path)
else None
)
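This tracks the Hugging Face Trainer API, whose train() method deprecated the model_path argument in favor of resume_from_checkpoint. A sketch of the resulting call site, with model_args and trainer as defined earlier in main.py:

```python
import os

# Resume only when the given path is an existing checkpoint directory;
# otherwise start training from scratch (resume_from_checkpoint=None).
checkpoint = (
    model_args.model_name_or_path
    if os.path.isdir(model_args.model_name_or_path)
    else None
)
trainer.train(resume_from_checkpoint=checkpoint)
```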
312 changes: 165 additions & 147 deletions multimodal_exp_args.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion multimodal_transformers/__init__.py
@@ -1,6 +1,6 @@
import multimodal_transformers.data
import multimodal_transformers.model

__version__ = "0.2-alpha"
__version__ = "0.3.0"

__all__ = ["multimodal_transformers", "__version__"]
29 changes: 11 additions & 18 deletions multimodal_transformers/model/tabular_combiner.py
@@ -95,17 +95,13 @@ def __init__(self, tabular_config):
self.numerical_feat_dim = tabular_config.numerical_feat_dim
self.num_labels = tabular_config.num_labels
self.numerical_bn = tabular_config.numerical_bn
+ self.categorical_bn = tabular_config.categorical_bn
self.mlp_act = tabular_config.mlp_act
self.mlp_dropout = tabular_config.mlp_dropout
self.mlp_division = tabular_config.mlp_division
self.text_out_dim = tabular_config.text_feat_dim
self.tabular_config = tabular_config

- if self.numerical_bn and self.numerical_feat_dim > 0:
-     self.num_bn = nn.BatchNorm1d(self.numerical_feat_dim)
- else:
-     self.num_bn = None

if self.combine_feat_method == "text_only":
self.final_out_dim = self.text_out_dim
elif self.combine_feat_method == "concat":
@@ -131,7 +127,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
- bn=True,
+ bn=self.categorical_bn,
)
self.final_out_dim = (
self.text_out_dim + output_dim + self.numerical_feat_dim
@@ -157,7 +153,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
- bn=True,
+ bn=self.categorical_bn and self.numerical_bn,
)
self.final_out_dim = self.text_out_dim + output_dim
elif (
@@ -181,7 +177,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
- bn=True,
+ bn=self.categorical_bn,
)

output_dim_num = 0
@@ -194,7 +190,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
num_hidden_lyr=1,
return_layer_outs=False,
- bn=True,
+ bn=self.numerical_bn,
)
self.final_out_dim = self.text_out_dim + output_dim_num + output_dim_cat
elif (
@@ -220,7 +216,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
- bn=True,
+ bn=self.categorical_bn,
)
else:
self.cat_layer = nn.Linear(self.cat_feat_dim, output_dim_cat)
@@ -242,7 +238,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
- bn=True,
+ bn=self.numerical_bn,
)
else:
self.num_layer = nn.Linear(self.numerical_feat_dim, output_dim_num)
@@ -275,7 +271,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
return_layer_outs=False,
hidden_channels=dims,
- bn=True,
+ bn=self.categorical_bn,
)
else:
output_dim_cat = self.cat_feat_dim
@@ -297,7 +293,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
return_layer_outs=False,
hidden_channels=dims,
- bn=True,
+ bn=self.numerical_bn,
)
else:
output_dim_num = self.numerical_feat_dim
@@ -330,7 +326,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
- bn=True,
+ bn=self.categorical_bn,
)
self.g_cat_layer = nn.Linear(
self.text_out_dim + min(self.text_out_dim, self.cat_feat_dim),
@@ -357,7 +353,7 @@ def __init__(self, tabular_config):
dropout_prob=self.mlp_dropout,
hidden_channels=dims,
return_layer_outs=False,
- bn=True,
+ bn=self.numerical_bn,
)
self.g_num_layer = nn.Linear(
min(self.numerical_feat_dim, self.text_out_dim) + self.text_out_dim,
@@ -398,9 +394,6 @@ def forward(self, text_feats, cat_feats=None, numerical_feats=None):
text_feats.device
)

- if self.numerical_bn and self.numerical_feat_dim != 0:
-     numerical_feats = self.num_bn(numerical_feats)

if self.combine_feat_method == "text_only":
combined_feats = text_feats
if self.combine_feat_method == "concat":
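The net effect of this file's changes: the combiner no longer owns a BatchNorm1d for numerical features, and every internal MLP's bn argument follows the config flags instead of being hard-coded to True. A minimal sketch of exercising the flags (import paths taken from the file names in this diff; dims are placeholders):

```python
from multimodal_transformers.model.tabular_config import TabularConfig
from multimodal_transformers.model.tabular_combiner import TabularFeatCombiner

config = TabularConfig(
    combine_feat_method="weighted_feature_sum_on_transformer_cat_and_numerical_feats",
    cat_feat_dim=9,
    numerical_feat_dim=5,
    num_labels=2,
    numerical_bn=False,    # no batchnorm anywhere on numerical feats
    categorical_bn=False,  # no batchnorm in the categorical MLP
)
config.text_feat_dim = 768  # normally copied from the transformer's hidden_size

# With both flags off, the combiner's MLPs are built with bn=False throughout.
combiner = TabularFeatCombiner(config)
```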
3 changes: 3 additions & 0 deletions multimodal_transformers/model/tabular_config.py
@@ -9,6 +9,7 @@ class TabularConfig:
See :obj:`TabularFeatCombiner` for details on the supported methods.
mlp_dropout (float): dropout ratio used for MLP layers
numerical_bn (bool): whether to use batchnorm on numerical features
+ categorical_bn (bool): whether to use batchnorm on categorical features
use_simple_classifier (bool): whether to use single layer or MLP as final classifier
mlp_act (str): the activation function to use for finetuning layers
gating_beta (float): the beta hyperparameters used for gating tabular data
@@ -25,6 +26,7 @@ def __init__(
combine_feat_method="text_only",
mlp_dropout=0.1,
numerical_bn=True,
+ categorical_bn=True,
use_simple_classifier=True,
mlp_act="relu",
gating_beta=0.2,
@@ -36,6 +38,7 @@
self.combine_feat_method = combine_feat_method
self.mlp_dropout = mlp_dropout
self.numerical_bn = numerical_bn
+ self.categorical_bn = categorical_bn
self.use_simple_classifier = use_simple_classifier
self.mlp_act = mlp_act
self.gating_beta = gating_beta
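Since categorical_bn defaults to True, existing configs keep their old behavior; only callers that opt out change anything. A short sketch of opting out for categorical features only (placeholder dims):

```python
from multimodal_transformers.model.tabular_config import TabularConfig

# Keep batchnorm on numerical features (default), skip it for the
# typically sparse, one-hot categorical features.
tabular_config = TabularConfig(
    cat_feat_dim=9,
    numerical_feat_dim=5,
    num_labels=2,
    categorical_bn=False,  # new flag added in this PR
)
```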
2 changes: 2 additions & 0 deletions multimodal_transformers/model/tabular_transformers.py
@@ -492,6 +492,7 @@ def __init__(self, hf_model_config):
self.config.tabular_config = tabular_config.__dict__

tabular_config.text_feat_dim = hf_model_config.hidden_size
+ tabular_config.hidden_dropout_prob = hf_model_config.hidden_dropout_prob
self.tabular_combiner = TabularFeatCombiner(tabular_config)
self.num_labels = tabular_config.num_labels
combined_feat_dim = self.tabular_combiner.final_out_dim
@@ -603,6 +604,7 @@ def __init__(self, hf_model_config):
self.config.tabular_config = tabular_config.__dict__

tabular_config.text_feat_dim = hf_model_config.hidden_size
+ tabular_config.hidden_dropout_prob = hf_model_config.hidden_dropout_prob
self.tabular_combiner = TabularFeatCombiner(tabular_config)
self.num_labels = tabular_config.num_labels
combined_feat_dim = self.tabular_combiner.final_out_dim
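Both model constructors now copy the text encoder's hidden_dropout_prob onto the tabular config alongside text_feat_dim, so the combiner sees the encoder's dropout rate as well as its hidden size; presumably (the consuming code is not shown in this diff) the combiner uses it for the layers that mix text and tabular features.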