From e334693082200c26617021860dfc8fea72e80929 Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Thu, 26 Oct 2023 10:43:52 -0400 Subject: [PATCH 01/12] Fix(.): model_path -> resume_from_checkpoint in Trainer.train() --- main.py | 2 +- tests/test_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 4169636..c2033d1 100644 --- a/main.py +++ b/main.py @@ -200,7 +200,7 @@ def compute_metrics_fn(p: EvalPrediction): ) if training_args.do_train: trainer.train( - model_path=model_args.model_name_or_path + resume_from_checkpoint=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None ) diff --git a/tests/test_model.py b/tests/test_model.py index 9bcade0..2ab2b97 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -170,7 +170,7 @@ def test_model(json_file: str, model_string: str): ) # Train - trainer.train(model_path=model_path) + trainer.train(resume_from_checkpoint=model_path) # Get predictions test_results = trainer.predict(test_dataset=test_dataset) From 405e8faa49de3faef893cdfe8c44ae5d31852516 Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Thu, 26 Oct 2023 10:46:13 -0400 Subject: [PATCH 02/12] Fix(docs/source/notes/introduction.rst): Fix #53 --- docs/source/notes/introduction.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/notes/introduction.rst b/docs/source/notes/introduction.rst index f786320..67fa5c7 100644 --- a/docs/source/notes/introduction.rst +++ b/docs/source/notes/introduction.rst @@ -32,7 +32,7 @@ Say for example we had categorical features of dim 9 and numerical features of d cat_feat_dim=9, # need to specify this numerical_feat_dim=5, # need to specify this num_labels=2, # need to specify this, assuming our task is binary classification - use_num_bn=False, + numerical_bn=False, ) bert_config.tabular_config = tabular_config From 8bb753d567969d7697908402f952e00ab42aa19d Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Thu, 26 Oct 2023 10:58:37 -0400 Subject: [PATCH 03/12] Docs(docs/source/notes/combine_methods.md): formatting. --- docs/source/notes/combine_methods.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/notes/combine_methods.md b/docs/source/notes/combine_methods.md index 38348f1..dc84ce4 100644 --- a/docs/source/notes/combine_methods.md +++ b/docs/source/notes/combine_methods.md @@ -18,15 +18,15 @@ The following describes each supported method and whether or not it requires bot | gating_on_cat_and_num_feats_then_sum | Gated summation of transformer outputs, numerical feats, and categorical feats before final classifier layer(s). Inspired by [Integrating Multimodal Information in Large Pretrained Transformers](https://www.aclweb.org/anthology/2020.acl-main.214.pdf) which performs the mechanism for each token. | False | weighted_feature_sum_on_transformer_cat_and_numerical_feats | Learnable weighted feature-wise sum of transformer outputs, numerical feats and categorical feats for each feature dimension before final classifier layer(s) | False -This table shows the the equations involved with each method. First we define some notation +This table shows the the equations involved with each method. 
First we define some notations: -* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bm%7D)  denotes the combined multimodal features -* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bx%7D)  denotes the output text features from the transformer -* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bc%7D)  denotes the categorical features -* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bn%7D)  denotes the numerical features -* ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20h_%7B%5Cmathbf%7B%5CTheta%7D%7D) denotes a MLP parameterized by ![equation](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7B%5CTheta%7D) -* ![equation](https://latex.codecogs.com/svg.latex?%5Cmathbf%7BW%7D)  denotes a weight matrix -* ![equation](https://latex.codecogs.com/svg.latex?b)  denotes a scalar bias +* ![m](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bm%7D)   denotes the combined multimodal features +* ![x](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bx%7D)   denotes the output text features from the transformer +* ![c](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bc%7D)   denotes the categorical features +* ![n](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7Bn%7D)   denotes the numerical features +* ![h_theta](https://latex.codecogs.com/svg.latex?%5Cinline%20h_%7B%5Cmathbf%7B%5CTheta%7D%7D) denotes a MLP parameterized by ![theta](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cmathbf%7B%5CTheta%7D) +* ![W](https://latex.codecogs.com/svg.latex?%5Cmathbf%7BW%7D)   denotes a weight matrix +* ![b](https://latex.codecogs.com/svg.latex?b)   denotes a scalar bias | Combine Feat Method | Equation | |:--------------|:-------------------| From e0dca14744c2c28193898933c8b66c51c0ff0024 Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Mon, 30 Oct 2023 11:08:12 -0400 Subject: [PATCH 04/12] Fix(.): Resolve #54 Created new variable for config - categorical_bn which determines batchnorm use for categorical features. All batchnorm is now done via the MLP class. Numerical BN is no longer done separately. 
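For context, a minimal usage sketch (not part of the patch itself) of how the two batchnorm flags would be set after this change, mirroring the introduction.rst example this series touches; the transformers import, model name, and chosen combine_feat_method are illustrative assumptions rather than taken from the diff:

    # Hypothetical sketch -- values mirror docs/source/notes/introduction.rst
    from transformers import AutoConfig                      # assumption: HF transformers installed
    from multimodal_transformers.model import TabularConfig  # module path as in this repo

    bert_config = AutoConfig.from_pretrained("bert-base-uncased")  # illustrative model choice
    tabular_config = TabularConfig(
        combine_feat_method="gating_on_cat_and_num_feats_then_sum",  # any method from combine_methods.md
        cat_feat_dim=9,          # dimension of the categorical features
        numerical_feat_dim=5,    # dimension of the numerical features
        num_labels=2,            # binary classification
        numerical_bn=True,       # batchnorm on numerical feats, now applied inside the MLP layers
        categorical_bn=True,     # new flag introduced by this commit (defaults to True)
    )
    bert_config.tabular_config = tabular_config

With this commit both flags are consumed by TabularFeatCombiner and forwarded to the MLP layers' bn argument; the separate nn.BatchNorm1d over the raw numerical features is removed.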
--- docs/source/notes/introduction.rst | 1 - multimodal_exp_args.py | 4 + .../model/tabular_combiner.py | 29 +- .../model/tabular_config.py | 3 + ...former_text_w_tabular_classification.ipynb | 1614 +++++++++-------- notebooks/text_w_tabular_classification.ipynb | 4 + 6 files changed, 831 insertions(+), 824 deletions(-) diff --git a/docs/source/notes/introduction.rst b/docs/source/notes/introduction.rst index 67fa5c7..0d7654f 100644 --- a/docs/source/notes/introduction.rst +++ b/docs/source/notes/introduction.rst @@ -32,7 +32,6 @@ Say for example we had categorical features of dim 9 and numerical features of d cat_feat_dim=9, # need to specify this numerical_feat_dim=5, # need to specify this num_labels=2, # need to specify this, assuming our task is binary classification - numerical_bn=False, ) bert_config.tabular_config = tabular_config diff --git a/multimodal_exp_args.py b/multimodal_exp_args.py index 0d28968..0da050f 100644 --- a/multimodal_exp_args.py +++ b/multimodal_exp_args.py @@ -111,6 +111,10 @@ class MultimodalDataTrainingArguments: metadata={ 'help': 'whether to use batchnorm on numerical features' }) + categorical_bn: bool = field(default=True, + metadata={ + 'help': 'whether to use batchnorm on categorical features' + }) use_simple_classifier: str = field(default=True, metadata={ 'help': 'whether to use single layer or MLP as final classifier' diff --git a/multimodal_transformers/model/tabular_combiner.py b/multimodal_transformers/model/tabular_combiner.py index 770324f..cf56114 100644 --- a/multimodal_transformers/model/tabular_combiner.py +++ b/multimodal_transformers/model/tabular_combiner.py @@ -95,17 +95,13 @@ def __init__(self, tabular_config): self.numerical_feat_dim = tabular_config.numerical_feat_dim self.num_labels = tabular_config.num_labels self.numerical_bn = tabular_config.numerical_bn + self.categorical_bn = tabular_config.categorical_bn self.mlp_act = tabular_config.mlp_act self.mlp_dropout = tabular_config.mlp_dropout self.mlp_division = tabular_config.mlp_division self.text_out_dim = tabular_config.text_feat_dim self.tabular_config = tabular_config - if self.numerical_bn and self.numerical_feat_dim > 0: - self.num_bn = nn.BatchNorm1d(self.numerical_feat_dim) - else: - self.num_bn = None - if self.combine_feat_method == "text_only": self.final_out_dim = self.text_out_dim elif self.combine_feat_method == "concat": @@ -131,7 +127,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, hidden_channels=dims, return_layer_outs=False, - bn=True, + bn=self.categorical_bn, ) self.final_out_dim = ( self.text_out_dim + output_dim + self.numerical_feat_dim @@ -157,7 +153,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, hidden_channels=dims, return_layer_outs=False, - bn=True, + bn=self.categorical_bn and self.numerical_bn, ) self.final_out_dim = self.text_out_dim + output_dim elif ( @@ -181,7 +177,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, hidden_channels=dims, return_layer_outs=False, - bn=True, + bn=self.categorical_bn, ) output_dim_num = 0 @@ -194,7 +190,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, num_hidden_lyr=1, return_layer_outs=False, - bn=True, + bn=self.numerical_bn, ) self.final_out_dim = self.text_out_dim + output_dim_num + output_dim_cat elif ( @@ -220,7 +216,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, hidden_channels=dims, return_layer_outs=False, - bn=True, + bn=self.categorical_bn, ) else: self.cat_layer = 
nn.Linear(self.cat_feat_dim, output_dim_cat) @@ -242,7 +238,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, hidden_channels=dims, return_layer_outs=False, - bn=True, + bn=self.numerical_bn, ) else: self.num_layer = nn.Linear(self.numerical_feat_dim, output_dim_num) @@ -275,7 +271,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, return_layer_outs=False, hidden_channels=dims, - bn=True, + bn=self.categorical_bn, ) else: output_dim_cat = self.cat_feat_dim @@ -297,7 +293,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, return_layer_outs=False, hidden_channels=dims, - bn=True, + bn=self.numerical_bn, ) else: output_dim_num = self.numerical_feat_dim @@ -330,7 +326,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, hidden_channels=dims, return_layer_outs=False, - bn=True, + bn=self.categorical_bn, ) self.g_cat_layer = nn.Linear( self.text_out_dim + min(self.text_out_dim, self.cat_feat_dim), @@ -357,7 +353,7 @@ def __init__(self, tabular_config): dropout_prob=self.mlp_dropout, hidden_channels=dims, return_layer_outs=False, - bn=True, + bn=self.numerical_bn, ) self.g_num_layer = nn.Linear( min(self.numerical_feat_dim, self.text_out_dim) + self.text_out_dim, @@ -398,9 +394,6 @@ def forward(self, text_feats, cat_feats=None, numerical_feats=None): text_feats.device ) - if self.numerical_bn and self.numerical_feat_dim != 0: - numerical_feats = self.num_bn(numerical_feats) - if self.combine_feat_method == "text_only": combined_feats = text_feats if self.combine_feat_method == "concat": diff --git a/multimodal_transformers/model/tabular_config.py b/multimodal_transformers/model/tabular_config.py index 9b18f5e..eb160c2 100644 --- a/multimodal_transformers/model/tabular_config.py +++ b/multimodal_transformers/model/tabular_config.py @@ -9,6 +9,7 @@ class TabularConfig: See :obj:`TabularFeatCombiner` for details on the supported methods. 
mlp_dropout (float): dropout ratio used for MLP layers numerical_bn (bool): whether to use batchnorm on numerical features + categorical_bn (bool): whether to use batchnorm on categorical features use_simple_classifier (bool): whether to use single layer or MLP as final classifier mlp_act (str): the activation function to use for finetuning layers gating_beta (float): the beta hyperparameters used for gating tabular data @@ -25,6 +26,7 @@ def __init__( combine_feat_method="text_only", mlp_dropout=0.1, numerical_bn=True, + categorical_bn=True, use_simple_classifier=True, mlp_act="relu", gating_beta=0.2, @@ -36,6 +38,7 @@ def __init__( self.combine_feat_method = combine_feat_method self.mlp_dropout = mlp_dropout self.numerical_bn = numerical_bn + self.categorical_bn = categorical_bn self.use_simple_classifier = use_simple_classifier self.mlp_act = mlp_act self.gating_beta = gating_beta diff --git a/notebooks/longformer_text_w_tabular_classification.ipynb b/notebooks/longformer_text_w_tabular_classification.ipynb index 0f1ebf5..3a389cc 100644 --- a/notebooks/longformer_text_w_tabular_classification.ipynb +++ b/notebooks/longformer_text_w_tabular_classification.ipynb @@ -20,14 +20,7 @@ }, { "cell_type": "code", - "source": [ - "gpu_info = !nvidia-smi\n", - "gpu_info = '\\n'.join(gpu_info)\n", - "if gpu_info.find('failed') >= 0:\n", - " print('Not connected to a GPU')\n", - "else:\n", - " print(gpu_info)" - ], + "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -35,11 +28,10 @@ "id": "UWEp8fjltYc5", "outputId": "94adc575-8962-4ddc-ba42-3075056dee3f" }, - "execution_count": 1, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Thu May 18 14:06:49 2023 \n", "+-----------------------------------------------------------------------------+\n", @@ -63,20 +55,19 @@ "+-----------------------------------------------------------------------------+\n" ] } + ], + "source": [ + "gpu_info = !nvidia-smi\n", + "gpu_info = '\\n'.join(gpu_info)\n", + "if gpu_info.find('failed') >= 0:\n", + " print('Not connected to a GPU')\n", + "else:\n", + " print(gpu_info)" ] }, { "cell_type": "code", - "source": [ - "from psutil import virtual_memory\n", - "ram_gb = virtual_memory().total / 1e9\n", - "print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n", - "\n", - "if ram_gb < 20:\n", - " print('Not using a high-RAM runtime')\n", - "else:\n", - " print('You are using a high-RAM runtime!')" - ], + "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -84,17 +75,26 @@ "id": "GbmTRG_WtaEz", "outputId": "ac0bab98-3729-4dd5-f26a-daa9e580b434" }, - "execution_count": 2, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Your runtime has 89.6 gigabytes of available RAM\n", "\n", "You are using a high-RAM runtime!\n" ] } + ], + "source": [ + "from psutil import virtual_memory\n", + "ram_gb = virtual_memory().total / 1e9\n", + "print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n", + "\n", + "if ram_gb < 20:\n", + " print('Not using a high-RAM runtime')\n", + "else:\n", + " print('You are using a high-RAM runtime!')" ] }, { @@ -110,8 +110,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting 
git+https://github.com/jtfields/Multimodal-Toolkit-Longformer.git\n", @@ -222,7 +222,6 @@ ] }, { - "output_type": "display_data", "data": { "application/vnd.colab-display-data+json": { "pip_warning": { @@ -232,7 +231,8 @@ } } }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -241,9 +241,7 @@ }, { "cell_type": "code", - "source": [ - "!pip install transformers==4.28.0" - ], + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -251,11 +249,10 @@ "id": "eFkmjDkDLmtL", "outputId": "219f0577-5dc7-4c6b-9d5f-5155202a7a7c" }, - "execution_count": 4, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting transformers==4.28.0\n", @@ -284,6 +281,9 @@ "Successfully installed transformers-4.28.0\n" ] } + ], + "source": [ + "!pip install transformers==4.28.0" ] }, { @@ -298,8 +298,8 @@ }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Mounted at /content/drive\n" ] @@ -388,39 +388,16 @@ "cell_type": "code", "execution_count": 8, "metadata": { - "id": "Ql8mJlKAUWX5", "colab": { "base_uri": "https://localhost:8080/", "height": 354 }, + "id": "Ql8mJlKAUWX5", "outputId": "5ebf0b87-84bf-4fce-a8de-3d9de8ab60a2" }, "outputs": [ { - "output_type": "execute_result", "data": { - "text/plain": [ - " Unnamed: 0 Clothing ID Age Title \\\n", - "0 0 767 33 NaN \n", - "1 1 1080 34 NaN \n", - "2 2 1077 60 Some major design flaws \n", - "3 3 1049 50 My favorite buy! \n", - "4 4 847 47 Flattering shirt \n", - "\n", - " Review Text Rating Recommended IND \\\n", - "0 Absolutely wonderful - silky and sexy and comf... 4 1 \n", - "1 Love this dress! it's sooo pretty. i happene... 5 1 \n", - "2 I had such high hopes for this dress and reall... 3 0 \n", - "3 I love, love, love this jumpsuit. it's fun, fl... 5 1 \n", - "4 This shirt is very flattering to all due to th... 5 1 \n", - "\n", - " Positive Feedback Count Division Name Department Name Class Name \n", - "0 0 Initmates Intimate Intimates \n", - "1 4 General Dresses Dresses \n", - "2 0 General Dresses Dresses \n", - "3 0 General Petite Bottoms Pants \n", - "4 6 General Tops Blouses " - ], "text/html": [ "\n", "
\n", @@ -606,10 +583,33 @@ "
\n", " \n", " " + ], + "text/plain": [ + " Unnamed: 0 Clothing ID Age Title \\\n", + "0 0 767 33 NaN \n", + "1 1 1080 34 NaN \n", + "2 2 1077 60 Some major design flaws \n", + "3 3 1049 50 My favorite buy! \n", + "4 4 847 47 Flattering shirt \n", + "\n", + " Review Text Rating Recommended IND \\\n", + "0 Absolutely wonderful - silky and sexy and comf... 4 1 \n", + "1 Love this dress! it's sooo pretty. i happene... 5 1 \n", + "2 I had such high hopes for this dress and reall... 3 0 \n", + "3 I love, love, love this jumpsuit. it's fun, fl... 5 1 \n", + "4 This shirt is very flattering to all due to th... 5 1 \n", + "\n", + " Positive Feedback Count Division Name Department Name Class Name \n", + "0 0 Initmates Intimate Intimates \n", + "1 4 General Dresses Dresses \n", + "2 0 General Dresses Dresses \n", + "3 0 General Petite Bottoms Pants \n", + "4 6 General Tops Blouses " ] }, + "execution_count": 8, "metadata": {}, - "execution_count": 8 + "output_type": "execute_result" } ], "source": [ @@ -619,16 +619,16 @@ }, { "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "GrDqLsGkHBBt" + }, + "outputs": [], "source": [ "# Sample the dataframe to reduce the file size so trainer can run\n", "#data_df = data_df.sample(n=100, random_state=1)\n", "#data_df" - ], - "metadata": { - "id": "GrDqLsGkHBBt" - }, - "execution_count": 9, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -643,17 +643,17 @@ "cell_type": "code", "execution_count": 10, "metadata": { - "id": "kpzF_0erRsuI", "colab": { "base_uri": "https://localhost:8080/", "height": 248 }, + "id": "kpzF_0erRsuI", "outputId": "cbd4abb8-36d4-4825-87a2-1efd95bb63ba" }, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ ":1: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To silence this warning, use `object` by itself. Doing this will not modify any behavior and is safe. \n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", @@ -661,21 +661,7 @@ ] }, { - "output_type": "execute_result", "data": { - "text/plain": [ - " Title Review Text \\\n", - "count 19676 22641 \n", - "unique 13993 22634 \n", - "top Love it! Perfect fit and i've gotten so many compliment... \n", - "freq 136 3 \n", - "\n", - " Division Name Department Name Class Name \n", - "count 23472 23472 23472 \n", - "unique 3 6 20 \n", - "top General Tops Dresses \n", - "freq 13850 10468 6319 " - ], "text/html": [ "\n", "
\n", @@ -817,10 +803,24 @@ "
\n", " \n", " " + ], + "text/plain": [ + " Title Review Text \\\n", + "count 19676 22641 \n", + "unique 13993 22634 \n", + "top Love it! Perfect fit and i've gotten so many compliment... \n", + "freq 136 3 \n", + "\n", + " Division Name Department Name Class Name \n", + "count 23472 23472 23472 \n", + "unique 3 6 20 \n", + "top General Tops Dresses \n", + "freq 13850 10468 6319 " ] }, + "execution_count": 10, "metadata": {}, - "execution_count": 10 + "output_type": "execute_result" } ], "source": [ @@ -831,38 +831,16 @@ "cell_type": "code", "execution_count": 11, "metadata": { - "id": "VSFD6Jl0RsuO", "colab": { "base_uri": "https://localhost:8080/", "height": 300 }, + "id": "VSFD6Jl0RsuO", "outputId": "cccbe0b3-986b-4735-b38f-e0802a4090eb" }, "outputs": [ { - "output_type": "execute_result", "data": { - "text/plain": [ - " Unnamed: 0 Clothing ID Age Rating \\\n", - "count 23486.000000 23486.000000 23486.000000 23486.000000 \n", - "mean 11742.500000 918.118709 43.198544 4.196032 \n", - "std 6779.968547 203.298980 12.279544 1.110031 \n", - "min 0.000000 0.000000 18.000000 1.000000 \n", - "25% 5871.250000 861.000000 34.000000 4.000000 \n", - "50% 11742.500000 936.000000 41.000000 5.000000 \n", - "75% 17613.750000 1078.000000 52.000000 5.000000 \n", - "max 23485.000000 1205.000000 99.000000 5.000000 \n", - "\n", - " Recommended IND Positive Feedback Count \n", - "count 23486.000000 23486.000000 \n", - "mean 0.822362 2.535936 \n", - "std 0.382216 5.702202 \n", - "min 0.000000 0.000000 \n", - "25% 1.000000 0.000000 \n", - "50% 1.000000 1.000000 \n", - "75% 1.000000 3.000000 \n", - "max 1.000000 122.000000 " - ], "text/html": [ "\n", "
\n", @@ -1045,10 +1023,32 @@ "
\n", " \n", " " + ], + "text/plain": [ + " Unnamed: 0 Clothing ID Age Rating \\\n", + "count 23486.000000 23486.000000 23486.000000 23486.000000 \n", + "mean 11742.500000 918.118709 43.198544 4.196032 \n", + "std 6779.968547 203.298980 12.279544 1.110031 \n", + "min 0.000000 0.000000 18.000000 1.000000 \n", + "25% 5871.250000 861.000000 34.000000 4.000000 \n", + "50% 11742.500000 936.000000 41.000000 5.000000 \n", + "75% 17613.750000 1078.000000 52.000000 5.000000 \n", + "max 23485.000000 1205.000000 99.000000 5.000000 \n", + "\n", + " Recommended IND Positive Feedback Count \n", + "count 23486.000000 23486.000000 \n", + "mean 0.822362 2.535936 \n", + "std 0.382216 5.702202 \n", + "min 0.000000 0.000000 \n", + "25% 1.000000 0.000000 \n", + "50% 1.000000 1.000000 \n", + "75% 1.000000 3.000000 \n", + "max 1.000000 122.000000 " ] }, + "execution_count": 11, "metadata": {}, - "execution_count": 11 + "output_type": "execute_result" } ], "source": [ @@ -1068,16 +1068,16 @@ "cell_type": "code", "execution_count": 12, "metadata": { - "id": "uQ41PkSZRsts", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "uQ41PkSZRsts", "outputId": "524ca3d8-5d5a-4ccc-8ae6-3e48e076c378" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Num examples train-val-test\n", "18788 2349 2349\n" @@ -1191,6 +1191,10 @@ " metadata={\n", " 'help': 'whether to use batchnorm on numerical features'\n", " })\n", + " categorical_bn: bool = field(default=True,\n", + " metadata={\n", + " 'help': 'whether to use batchnorm on categorical features'\n", + " })\n", " use_simple_classifier: str = field(default=True,\n", " metadata={\n", " 'help': 'whether to use single layer or MLP as final classifier'\n", @@ -1225,25 +1229,25 @@ }, { "cell_type": "code", - "source": [ - "#pip uninstall -y transformers accelerate" - ], + "execution_count": 14, "metadata": { "id": "3KjfxKsxIvXC" }, - "execution_count": 14, - "outputs": [] + "outputs": [], + "source": [ + "#pip uninstall -y transformers accelerate" + ] }, { "cell_type": "code", - "source": [ - "#pip install transformers accelerate" - ], + "execution_count": 15, "metadata": { "id": "w96bJG0dJ2eb" }, - "execution_count": 15, - "outputs": [] + "outputs": [], + "source": [ + "#pip install transformers accelerate" + ] }, { "cell_type": "code", @@ -1312,7 +1316,6 @@ "cell_type": "code", "execution_count": 17, "metadata": { - "id": "38GJb0Y7RsuZ", "colab": { "base_uri": "https://localhost:8080/", "height": 163, @@ -1363,71 +1366,72 @@ "5ad5ad4c3abf4339add1bb374766d962" ] }, + "id": "38GJb0Y7RsuZ", "outputId": "2c4f7507-7e22-4f1d-d207-3b12a357b25e" }, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "Specified tokenizer: allenai/longformer-base-4096\n" ] }, { - "output_type": "display_data", "data": { - "text/plain": [ - "Downloading (…)lve/main/config.json: 0%| | 0.00/694 [00:00" - ], "text/html": [ "\n", "
\n", @@ -2042,27 +2042,31 @@ " \n", " \n", "

" + ], + "text/plain": [ + "" ] }, - "metadata": {} + "metadata": {}, + "output_type": "display_data" }, { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "CPU times: user 9min 48s, sys: 3min 49s, total: 13min 37s\n", "Wall time: 13min 21s\n" ] }, { - "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=588, training_loss=0.21648866992418458, metrics={'train_runtime': 801.7556, 'train_samples_per_second': 23.434, 'train_steps_per_second': 0.733, 'total_flos': 2039261215755456.0, 'train_loss': 0.21648866992418458, 'epoch': 1.0})" ] }, + "execution_count": 34, "metadata": {}, - "execution_count": 34 + "output_type": "execute_result" } ], "source": [ @@ -2106,9 +2110,10 @@ "metadata": { "accelerator": "GPU", "colab": { - "provenance": [], - "machine_shape": "hm" + "machine_shape": "hm", + "provenance": [] }, + "gpuClass": "premium", "kernelspec": { "display_name": "Python 3", "language": "python", @@ -2126,101 +2131,12 @@ "pygments_lexer": "ipython3", "version": "3.7.3" }, - "gpuClass": "premium", "widgets": { "application/vnd.jupyter.widget-state+json": { - "1fb09eb6298c45c2a51493e72f5ff6eb": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_ef2dbddc5dc44348a37e7e9850dcf635", - "IPY_MODEL_28b5d66502164a5f96014fa374280ff1", - "IPY_MODEL_5d76c4d905df4e45918525b7b84a816a" - ], - "layout": "IPY_MODEL_5cba7b7d76a94e5eaeb9044d54b6009d" - } - }, - "ef2dbddc5dc44348a37e7e9850dcf635": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_9705adc0257943dea3ce436350fb2598", - "placeholder": "​", - "style": "IPY_MODEL_31274762f15e4c44b1812ed8fd0c1c17", - "value": "Downloading (…)lve/main/config.json: 100%" - } - }, - "28b5d66502164a5f96014fa374280ff1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_43f36a47df384583bc2a40b085406c3d", - "max": 694, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_20660bab65814893ab1216af51918912", - "value": 694 - } - }, - "5d76c4d905df4e45918525b7b84a816a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, 
- "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_eff0887e3b0b4d90afe79659e4ad471d", - "placeholder": "​", - "style": "IPY_MODEL_5def6b6f194f4c5f81e63c2652c5bfe6", - "value": " 694/694 [00:00<00:00, 62.3kB/s]" - } - }, - "5cba7b7d76a94e5eaeb9044d54b6009d": { + "0445cf8a7edf427abe09648d7e55f8ee": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -2269,10 +2185,34 @@ "width": null } }, - "9705adc0257943dea3ce436350fb2598": { + "0898eff774d74e6987fe2f52b5efc497": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_95e5f2e112cf41ceac466c6afa015e1b", + "max": 597257159, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_5132ab80c27e41e89f0eabd83302d3fc", + "value": 597257159 + } + }, + "0d563c824754464bb0c849e382b7128d": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -2321,77 +2261,107 @@ "width": null } }, - "31274762f15e4c44b1812ed8fd0c1c17": { + "146e20e434984530a062fe69cca63ae2": { "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", + "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", + "bar_color": null, "description_width": "" } }, - "43f36a47df384583bc2a40b085406c3d": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", + "18fc8488734f432e9be491fdfcec8f40": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - 
"left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7d623d244e684f4482a02252b00e4883", + "placeholder": "​", + "style": "IPY_MODEL_e50a2630439f459b9bb7662e2f8e907a", + "value": "Downloading (…)olve/main/merges.txt: 100%" } }, - "20660bab65814893ab1216af51918912": { + "1b812bed073d4c2d94248d71339d28e9": { "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "1e40fddf402e4937939bd3d8e72ae074": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b878c8c2fc31459d85496589d5f44dcb", + "IPY_MODEL_e4ad9261e026450787bd34d96e7a184a", + "IPY_MODEL_ee460c10113d42a096123589161cd6fe" + ], + "layout": "IPY_MODEL_f61fbad8ae4d4f50919aefad32257565" + } + }, + "1fb09eb6298c45c2a51493e72f5ff6eb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ef2dbddc5dc44348a37e7e9850dcf635", + "IPY_MODEL_28b5d66502164a5f96014fa374280ff1", + "IPY_MODEL_5d76c4d905df4e45918525b7b84a816a" + ], + "layout": "IPY_MODEL_5cba7b7d76a94e5eaeb9044d54b6009d" + } + }, + "20660bab65814893ab1216af51918912": { + "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", @@ -2404,10 +2374,10 @@ "description_width": "" } }, - "eff0887e3b0b4d90afe79659e4ad471d": { + "211b6d4cb5454703942b17b843c93196": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -2456,113 +2426,85 @@ "width": null } }, - "5def6b6f194f4c5f81e63c2652c5bfe6": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - 
"_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "1e40fddf402e4937939bd3d8e72ae074": { + "28b5d66502164a5f96014fa374280ff1": { "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", + "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_b878c8c2fc31459d85496589d5f44dcb", - "IPY_MODEL_e4ad9261e026450787bd34d96e7a184a", - "IPY_MODEL_ee460c10113d42a096123589161cd6fe" - ], - "layout": "IPY_MODEL_f61fbad8ae4d4f50919aefad32257565" + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_43f36a47df384583bc2a40b085406c3d", + "max": 694, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_20660bab65814893ab1216af51918912", + "value": 694 } }, - "b878c8c2fc31459d85496589d5f44dcb": { + "2e592022645b4898839ecf89dd4aec0e": { "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", "state": { - "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", + "_model_name": "DescriptionStyleModel", "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e9f349ae18a445bf9fc452f4fff92593", - "placeholder": "​", - "style": "IPY_MODEL_72cb227ec3f7409b8605ca674982a8c6", - "value": "Downloading (…)olve/main/vocab.json: 100%" + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" } }, - "e4ad9261e026450787bd34d96e7a184a": { + "2e6b245eeef64311921ce0d9ad29db3c": { "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", "model_module_version": "1.5.0", + "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", + "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", + "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_c91d67fea4d845369f5d9de3dd660407", - "max": 898823, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_1b812bed073d4c2d94248d71339d28e9", - "value": 898823 + "layout": "IPY_MODEL_4983642db7e6493cada0770b76c64de2", + "placeholder": "​", + "style": "IPY_MODEL_5ad5ad4c3abf4339add1bb374766d962", + "value": " 1.36M/1.36M [00:00<00:00, 14.3MB/s]" } }, - "ee460c10113d42a096123589161cd6fe": { + "31274762f15e4c44b1812ed8fd0c1c17": { "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", "state": { - "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", + "_model_name": 
"DescriptionStyleModel", "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_5407e56eab4b4a6f99001a0ed8528974", - "placeholder": "​", - "style": "IPY_MODEL_653ec47809804f36b92ae725ea60dc32", - "value": " 899k/899k [00:00<00:00, 2.72MB/s]" + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" } }, - "f61fbad8ae4d4f50919aefad32257565": { + "31705701df4348a3893382fdd2d5a600": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -2611,10 +2553,10 @@ "width": null } }, - "e9f349ae18a445bf9fc452f4fff92593": { + "43f36a47df384583bc2a40b085406c3d": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -2663,25 +2605,10 @@ "width": null } }, - "72cb227ec3f7409b8605ca674982a8c6": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "c91d67fea4d845369f5d9de3dd660407": { + "4983642db7e6493cada0770b76c64de2": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -2730,26 +2657,10 @@ "width": null } }, - "1b812bed073d4c2d94248d71339d28e9": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "5407e56eab4b4a6f99001a0ed8528974": { + "4d78302a2fc3449f9db3df9c8500fa99": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -2798,25 +2709,26 @@ "width": null } }, - "653ec47809804f36b92ae725ea60dc32": { + "5132ab80c27e41e89f0eabd83302d3fc": { "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", + "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", + "bar_color": null, "description_width": "" } }, - "c8651597aad743b380a6bd8a24490fa4": { + "522eceeae9f9465081d6904705b82e5e": { "model_module": 
"@jupyter-widgets/controls", - "model_name": "HBoxModel", "model_module_version": "1.5.0", + "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", @@ -2828,83 +2740,17 @@ "_view_name": "HBoxView", "box_style": "", "children": [ - "IPY_MODEL_18fc8488734f432e9be491fdfcec8f40", - "IPY_MODEL_bbf4462a91de4286a6352005c79520e9", - "IPY_MODEL_6e67e2354adf495bb2d29aa44c9ca344" + "IPY_MODEL_7b43ea4a0be74799a574088d69a3bcd9", + "IPY_MODEL_0898eff774d74e6987fe2f52b5efc497", + "IPY_MODEL_ff11eaaaaf754727ac565af3f705dfcf" ], - "layout": "IPY_MODEL_31705701df4348a3893382fdd2d5a600" - } - }, - "18fc8488734f432e9be491fdfcec8f40": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7d623d244e684f4482a02252b00e4883", - "placeholder": "​", - "style": "IPY_MODEL_e50a2630439f459b9bb7662e2f8e907a", - "value": "Downloading (…)olve/main/merges.txt: 100%" - } - }, - "bbf4462a91de4286a6352005c79520e9": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_94be463bf8634a47ba3155b1b8474756", - "max": 456318, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_9cb14500689f4ae696a6cff5a680335a", - "value": 456318 - } - }, - "6e67e2354adf495bb2d29aa44c9ca344": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_0445cf8a7edf427abe09648d7e55f8ee", - "placeholder": "​", - "style": "IPY_MODEL_8d5bfad75af047358592ce5e5c5d0a2e", - "value": " 456k/456k [00:00<00:00, 2.64MB/s]" + "layout": "IPY_MODEL_0d563c824754464bb0c849e382b7128d" } }, - "31705701df4348a3893382fdd2d5a600": { + "5407e56eab4b4a6f99001a0ed8528974": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -2953,10 +2799,25 @@ "width": null } }, - "7d623d244e684f4482a02252b00e4883": { + "5ad5ad4c3abf4339add1bb374766d962": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "5cba7b7d76a94e5eaeb9044d54b6009d": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3005,10 +2866,85 @@ "width": null } }, - "e50a2630439f459b9bb7662e2f8e907a": { + "5d76c4d905df4e45918525b7b84a816a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_eff0887e3b0b4d90afe79659e4ad471d", + "placeholder": "​", + "style": "IPY_MODEL_5def6b6f194f4c5f81e63c2652c5bfe6", + "value": " 694/694 [00:00<00:00, 62.3kB/s]" + } + }, + "5def6b6f194f4c5f81e63c2652c5bfe6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "653ec47809804f36b92ae725ea60dc32": { "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "6cdbb49414c14abd9ae457deb2e61c2e": { + "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9d43ccd97f5f4a6b9a6c3bb32b8fd760", + "max": 1355863, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_146e20e434984530a062fe69cca63ae2", + "value": 1355863 + } + }, + "6ded2ac25e2843509d86dc3e2677c506": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", @@ -3020,10 +2956,82 @@ "description_width": "" } }, - "94be463bf8634a47ba3155b1b8474756": { + "6e67e2354adf495bb2d29aa44c9ca344": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": 
"HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0445cf8a7edf427abe09648d7e55f8ee", + "placeholder": "​", + "style": "IPY_MODEL_8d5bfad75af047358592ce5e5c5d0a2e", + "value": " 456k/456k [00:00<00:00, 2.64MB/s]" + } + }, + "72cb227ec3f7409b8605ca674982a8c6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "773870f3767d44cea0bb61789b6e7524": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "7b43ea4a0be74799a574088d69a3bcd9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_211b6d4cb5454703942b17b843c93196", + "placeholder": "​", + "style": "IPY_MODEL_6ded2ac25e2843509d86dc3e2677c506", + "value": "Downloading pytorch_model.bin: 100%" + } + }, + "7d623d244e684f4482a02252b00e4883": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3072,26 +3080,25 @@ "width": null } }, - "9cb14500689f4ae696a6cff5a680335a": { + "8d5bfad75af047358592ce5e5c5d0a2e": { "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", + "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", - "bar_color": null, "description_width": "" } }, - "0445cf8a7edf427abe09648d7e55f8ee": { + "94be463bf8634a47ba3155b1b8474756": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3140,113 +3147,10 @@ "width": null } }, - "8d5bfad75af047358592ce5e5c5d0a2e": { - "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - 
"description_width": "" - } - }, - "bb31de1c0d2147ba84c669ced99bb2b4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HBoxModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_ea40a62ca0f841ea9248323891068bd7", - "IPY_MODEL_6cdbb49414c14abd9ae457deb2e61c2e", - "IPY_MODEL_2e6b245eeef64311921ce0d9ad29db3c" - ], - "layout": "IPY_MODEL_a1c452159f274cf882b081dfb7c6c357" - } - }, - "ea40a62ca0f841ea9248323891068bd7": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_4d78302a2fc3449f9db3df9c8500fa99", - "placeholder": "​", - "style": "IPY_MODEL_773870f3767d44cea0bb61789b6e7524", - "value": "Downloading (…)/main/tokenizer.json: 100%" - } - }, - "6cdbb49414c14abd9ae457deb2e61c2e": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_9d43ccd97f5f4a6b9a6c3bb32b8fd760", - "max": 1355863, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_146e20e434984530a062fe69cca63ae2", - "value": 1355863 - } - }, - "2e6b245eeef64311921ce0d9ad29db3c": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_4983642db7e6493cada0770b76c64de2", - "placeholder": "​", - "style": "IPY_MODEL_5ad5ad4c3abf4339add1bb374766d962", - "value": " 1.36M/1.36M [00:00<00:00, 14.3MB/s]" - } - }, - "a1c452159f274cf882b081dfb7c6c357": { + "95e5f2e112cf41ceac466c6afa015e1b": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3295,10 +3199,10 @@ "width": null } }, - "4d78302a2fc3449f9db3df9c8500fa99": { + "9705adc0257943dea3ce436350fb2598": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3347,25 +3251,78 @@ "width": null } }, - "773870f3767d44cea0bb61789b6e7524": { + 
"9cb14500689f4ae696a6cff5a680335a": { "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", + "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", + "bar_color": null, "description_width": "" } }, "9d43ccd97f5f4a6b9a6c3bb32b8fd760": { "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a1c452159f274cf882b081dfb7c6c357": { + "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3414,26 +3371,77 @@ "width": null } }, - "146e20e434984530a062fe69cca63ae2": { + "b878c8c2fc31459d85496589d5f44dcb": { "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", + "model_name": "HTMLModel", "state": { + "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", + "_model_name": "HTMLModel", "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e9f349ae18a445bf9fc452f4fff92593", + "placeholder": "​", + "style": "IPY_MODEL_72cb227ec3f7409b8605ca674982a8c6", + "value": "Downloading (…)olve/main/vocab.json: 100%" + } + }, + "bb31de1c0d2147ba84c669ced99bb2b4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_ea40a62ca0f841ea9248323891068bd7", + 
"IPY_MODEL_6cdbb49414c14abd9ae457deb2e61c2e", + "IPY_MODEL_2e6b245eeef64311921ce0d9ad29db3c" + ], + "layout": "IPY_MODEL_a1c452159f274cf882b081dfb7c6c357" + } + }, + "bbf4462a91de4286a6352005c79520e9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_94be463bf8634a47ba3155b1b8474756", + "max": 456318, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_9cb14500689f4ae696a6cff5a680335a", + "value": 456318 } }, - "4983642db7e6493cada0770b76c64de2": { + "c42bedf951844e94bd0399783e69da14": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3482,25 +3490,10 @@ "width": null } }, - "5ad5ad4c3abf4339add1bb374766d962": { + "c8651597aad743b380a6bd8a24490fa4": { "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "522eceeae9f9465081d6904705b82e5e": { - "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", - "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", @@ -3512,83 +3505,17 @@ "_view_name": "HBoxView", "box_style": "", "children": [ - "IPY_MODEL_7b43ea4a0be74799a574088d69a3bcd9", - "IPY_MODEL_0898eff774d74e6987fe2f52b5efc497", - "IPY_MODEL_ff11eaaaaf754727ac565af3f705dfcf" + "IPY_MODEL_18fc8488734f432e9be491fdfcec8f40", + "IPY_MODEL_bbf4462a91de4286a6352005c79520e9", + "IPY_MODEL_6e67e2354adf495bb2d29aa44c9ca344" ], - "layout": "IPY_MODEL_0d563c824754464bb0c849e382b7128d" - } - }, - "7b43ea4a0be74799a574088d69a3bcd9": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_211b6d4cb5454703942b17b843c93196", - "placeholder": "​", - "style": "IPY_MODEL_6ded2ac25e2843509d86dc3e2677c506", - "value": "Downloading pytorch_model.bin: 100%" - } - }, - "0898eff774d74e6987fe2f52b5efc497": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatProgressModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - 
"bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_95e5f2e112cf41ceac466c6afa015e1b", - "max": 597257159, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_5132ab80c27e41e89f0eabd83302d3fc", - "value": 597257159 - } - }, - "ff11eaaaaf754727ac565af3f705dfcf": { - "model_module": "@jupyter-widgets/controls", - "model_name": "HTMLModel", - "model_module_version": "1.5.0", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_c42bedf951844e94bd0399783e69da14", - "placeholder": "​", - "style": "IPY_MODEL_2e592022645b4898839ecf89dd4aec0e", - "value": " 597M/597M [00:02<00:00, 306MB/s]" + "layout": "IPY_MODEL_31705701df4348a3893382fdd2d5a600" } }, - "0d563c824754464bb0c849e382b7128d": { + "c91d67fea4d845369f5d9de3dd660407": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3637,10 +3564,49 @@ "width": null } }, - "211b6d4cb5454703942b17b843c93196": { + "e4ad9261e026450787bd34d96e7a184a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c91d67fea4d845369f5d9de3dd660407", + "max": 898823, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1b812bed073d4c2d94248d71339d28e9", + "value": 898823 + } + }, + "e50a2630439f459b9bb7662e2f8e907a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e9f349ae18a445bf9fc452f4fff92593": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3689,25 +3655,73 @@ "width": null } }, - "6ded2ac25e2843509d86dc3e2677c506": { + "ea40a62ca0f841ea9248323891068bd7": { "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", + "model_name": "HTMLModel", "state": { + "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", + "_model_name": "HTMLModel", "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": 
"1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4d78302a2fc3449f9db3df9c8500fa99", + "placeholder": "​", + "style": "IPY_MODEL_773870f3767d44cea0bb61789b6e7524", + "value": "Downloading (…)/main/tokenizer.json: 100%" } }, - "95e5f2e112cf41ceac466c6afa015e1b": { + "ee460c10113d42a096123589161cd6fe": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5407e56eab4b4a6f99001a0ed8528974", + "placeholder": "​", + "style": "IPY_MODEL_653ec47809804f36b92ae725ea60dc32", + "value": " 899k/899k [00:00<00:00, 2.72MB/s]" + } + }, + "ef2dbddc5dc44348a37e7e9850dcf635": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9705adc0257943dea3ce436350fb2598", + "placeholder": "​", + "style": "IPY_MODEL_31274762f15e4c44b1812ed8fd0c1c17", + "value": "Downloading (…)lve/main/config.json: 100%" + } + }, + "eff0887e3b0b4d90afe79659e4ad471d": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3756,26 +3770,10 @@ "width": null } }, - "5132ab80c27e41e89f0eabd83302d3fc": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ProgressStyleModel", - "model_module_version": "1.5.0", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "c42bedf951844e94bd0399783e69da14": { + "f61fbad8ae4d4f50919aefad32257565": { "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", @@ -3824,19 +3822,25 @@ "width": null } }, - "2e592022645b4898839ecf89dd4aec0e": { + "ff11eaaaaf754727ac565af3f705dfcf": { "model_module": "@jupyter-widgets/controls", - "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", + "model_name": "HTMLModel", "state": { + "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", + "_model_name": "HTMLModel", "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, 
+ "layout": "IPY_MODEL_c42bedf951844e94bd0399783e69da14", + "placeholder": "​", + "style": "IPY_MODEL_2e592022645b4898839ecf89dd4aec0e", + "value": " 597M/597M [00:02<00:00, 306MB/s]" } } } diff --git a/notebooks/text_w_tabular_classification.ipynb b/notebooks/text_w_tabular_classification.ipynb index 38bb563..657efaa 100644 --- a/notebooks/text_w_tabular_classification.ipynb +++ b/notebooks/text_w_tabular_classification.ipynb @@ -868,6 +868,10 @@ " metadata={\n", " 'help': 'whether to use batchnorm on numerical features'\n", " })\n", + " categorical_bn: bool = field(default=True,\n", + " metadata={\n", + " 'help': 'whether to use batchnorm on categorical features'\n", + " })\n", " use_simple_classifier: str = field(default=True,\n", " metadata={\n", " 'help': 'whether to use single layer or MLP as final classifier'\n", From 5bbc06f4e2ba1bdbab7c164d06bebb4f427da4df Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Mon, 30 Oct 2023 11:57:59 -0400 Subject: [PATCH 05/12] Fix(multimodal_exp_args.py): Resolve #34 --- multimodal_exp_args.py | 42 ++---------------------------------------- 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/multimodal_exp_args.py b/multimodal_exp_args.py index 0da050f..5fe0a3d 100644 --- a/multimodal_exp_args.py +++ b/multimodal_exp_args.py @@ -153,11 +153,6 @@ class OurTrainingArguments(TrainingArguments): metadata={'help': 'A name for the experiment'} ) - gpu_num: int = field( - default=0, - metadata={'help': 'The gpu number to train on'} - ) - debug_dataset: bool = field( default=False, metadata={'help': 'Whether we are training in debug mode (smaller model)'} @@ -187,42 +182,9 @@ class OurTrainingArguments(TrainingArguments): learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) - report_to: Optional[List[str]] = field( - default_factory=list, metadata={"help": "The list of integrations to report the results and logs to."} - ) - def __post_init__(self): + super().__post_init__() if self.debug_dataset: self.max_token_length = 16 self.logging_steps = 5 - self.overwrite_output_dir = True - - - @cached_property - def _setup_devices(self) -> Tuple["torch.device", int]: - requires_backends(self, ["torch"]) - logger.info("PyTorch: setting up devices") - if self.no_cuda: - device = torch.device("cpu") - self._n_gpu = 0 - elif self.local_rank == -1: - # if n_gpu is > 1 we'll use nn.DataParallel. - # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` - # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will - # trigger an error that a device index is missing. Index 0 takes into account the - # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` - # will use the first GPU in that env, i.e. GPU#1 - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self._n_gpu = torch.cuda.device_count() - else: - # Here, we'll use torch.distributed. 
- # Initializes the distributed backend which will take care of sychronizing nodes/GPUs - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta) - device = torch.device("cuda", self.local_rank) - self._n_gpu = 1 - - if device.type == "cuda": - torch.cuda.set_device(device) - - return device \ No newline at end of file + self.overwrite_output_dir = True \ No newline at end of file From 38ec78961feba9b05fdddcac9fe9ca5e0e104dc8 Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Mon, 30 Oct 2023 12:02:14 -0400 Subject: [PATCH 06/12] Fix(multimodal_transformers/model/tabular_transformers.py): Resolve #46 --- multimodal_transformers/model/tabular_transformers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/multimodal_transformers/model/tabular_transformers.py b/multimodal_transformers/model/tabular_transformers.py index 2853c2a..063fa33 100644 --- a/multimodal_transformers/model/tabular_transformers.py +++ b/multimodal_transformers/model/tabular_transformers.py @@ -492,6 +492,7 @@ def __init__(self, hf_model_config): self.config.tabular_config = tabular_config.__dict__ tabular_config.text_feat_dim = hf_model_config.hidden_size + tabular_config.hidden_dropout_prob = hf_model_config.hidden_dropout_prob self.tabular_combiner = TabularFeatCombiner(tabular_config) self.num_labels = tabular_config.num_labels combined_feat_dim = self.tabular_combiner.final_out_dim @@ -603,6 +604,7 @@ def __init__(self, hf_model_config): self.config.tabular_config = tabular_config.__dict__ tabular_config.text_feat_dim = hf_model_config.hidden_size + tabular_config.hidden_dropout_prob = hf_model_config.hidden_dropout_prob self.tabular_combiner = TabularFeatCombiner(tabular_config) self.num_labels = tabular_config.num_labels combined_feat_dim = self.tabular_combiner.final_out_dim From 5bab76d4529aca3de5fc6a4899067998ced0fa27 Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Mon, 30 Oct 2023 13:43:15 -0400 Subject: [PATCH 07/12] Fix(setup.py): Library update to handle bugs/security/maintenance. 
--- setup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index c7476b6..0f356b6 100644 --- a/setup.py +++ b/setup.py @@ -24,14 +24,14 @@ def get_version(rel_path: str) -> str: url = "https://github.com/georgianpartners/Multimodal-Toolkit" install_requires = [ - "transformers>=4.26.1", - "torch>=1.13.1", + "transformers>=4.34.1", + "torch>=2.0.1", "sacremoses~=0.0.53", "networkx~=2.6.3", "scikit-learn~=1.0.2", - "scipy~=1.7.3", + "scipy~=1.11.3", "pandas~=1.3.5", - "numpy~=1.21.6", + "numpy~=1.26.1", "tqdm~=4.64.1", "pytest~=7.2.2", ] @@ -58,7 +58,7 @@ def get_version(rel_path: str) -> str: install_requires=install_requires, python_requires=">=3.7", classifiers=[ - "Development Status :: 3 - Alpha", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package + "Development Status :: 5 - Production/Stable", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package "Intended Audience :: Developers", # Define that your audience are developers "Topic :: Software Development :: Build Tools", "License :: OSI Approved :: MIT License", # Again, pick a license From 2d9ff39174f959b4440561a1bf4590a0bb8fde73 Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Mon, 30 Oct 2023 13:45:42 -0400 Subject: [PATCH 08/12] Docs(.): Update version; thanks to #33 --- multimodal_transformers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/multimodal_transformers/__init__.py b/multimodal_transformers/__init__.py index 5b016a2..7a2920d 100644 --- a/multimodal_transformers/__init__.py +++ b/multimodal_transformers/__init__.py @@ -1,6 +1,6 @@ import multimodal_transformers.data import multimodal_transformers.model -__version__ = "0.2-alpha" +__version__ = "0.3.0" __all__ = ["multimodal_transformers", "__version__"] From 6f459ec110d27a702184502f74cc549fa364678d Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Tue, 7 Nov 2023 12:59:09 -0700 Subject: [PATCH 09/12] Fix(setup.py): Add accelerate to reqs. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 0f356b6..f31165e 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ def get_version(rel_path: str) -> str: install_requires = [ "transformers>=4.34.1", + "accelerate>=0.24.1", "torch>=2.0.1", "sacremoses~=0.0.53", "networkx~=2.6.3", From d749fc19f8586445ba178ef672eaef41e9efaed5 Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Tue, 7 Nov 2023 12:59:37 -0700 Subject: [PATCH 10/12] Fix(multimodal_exp_args.py): Remove invalid param. --- multimodal_exp_args.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/multimodal_exp_args.py b/multimodal_exp_args.py index 5fe0a3d..8729071 100644 --- a/multimodal_exp_args.py +++ b/multimodal_exp_args.py @@ -166,10 +166,6 @@ class OurTrainingArguments(TrainingArguments): do_eval: bool = field(default=True, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=True, metadata={"help": "Whether to run predictions on the test set."}) - evaluate_during_training: bool = field( - default=True, metadata={"help": "Run evaluation during training at each logging step."}, - ) - max_token_length: Optional[int] = field( default=None, metadata={'help': 'The maximum token length'} From f2a95ee9bb7ab68a24f0704ce61b6e2cd30ee13e Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Tue, 7 Nov 2023 13:01:01 -0700 Subject: [PATCH 11/12] Refactor(multimodal_exp_args.py): Format with black. 
--- multimodal_exp_args.py | 278 +++++++++++++++++++++++++---------------- 1 file changed, 167 insertions(+), 111 deletions(-) diff --git a/multimodal_exp_args.py b/multimodal_exp_args.py index 8729071..0c3ff25 100644 --- a/multimodal_exp_args.py +++ b/multimodal_exp_args.py @@ -4,7 +4,11 @@ from typing import Optional, Tuple, List import torch -from transformers.training_args import TrainingArguments, requires_backends, cached_property +from transformers.training_args import ( + TrainingArguments, + requires_backends, + cached_property, +) logger = logging.getLogger(__name__) @@ -17,20 +21,30 @@ class ModelArguments: """ model_name_or_path: str = field( - metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + metadata={ + "help": "Path to pretrained model or model identifier from huggingface.co/models" + } ) config_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, ) tokenizer_name: Optional[str] = field( - default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, ) cache_dir: Optional[str] = field( - default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from s3" + }, ) - @dataclass class MultimodalDataTrainingArguments: """ @@ -40,147 +54,189 @@ class MultimodalDataTrainingArguments: the command line. """ - data_path: str = field(metadata={ - 'help': 'the path to the csv files containing the dataset. If create_folds is set to True' - 'then it is expected that data_path points to one csv containing the entire dataset' - 'to split into folds. Otherwise, data_path should be the folder containing' - 'train.csv, test.csv, (and val.csv if available)' - }) - create_folds: bool = field(default=False, - metadata={'help': 'Whether or not we want to create folds for ' - 'K fold evaluation of the model'}) - - num_folds: int = field(default=5, - metadata={'help': 'The number of folds for K fold ' - 'evaluation of the model. Will not be used if create_folds is False'}) - validation_ratio: float = field(default=0.2, - metadata={'help': 'The ratio of dataset examples to be used for validation across' - 'all folds for K fold evaluation. If num_folds is 5 and ' - 'validation_ratio is 0.2. Then a consistent 20% of the examples will' - 'be used for validation for all folds. Then the remaining 80% is used' - 'for K fold split for test and train sets so 0.2*0.8=16% of ' - 'all examples is used for testing and 0.8*0.8=64% of all examples' - 'is used for training for each fold'} - ) - num_classes: int = field(default=-1, - metadata={'help': 'Number of labels for classification if any'}) + data_path: str = field( + metadata={ + "help": "the path to the csv files containing the dataset. If create_folds is set to True" + "then it is expected that data_path points to one csv containing the entire dataset" + "to split into folds. 
Otherwise, data_path should be the folder containing" + "train.csv, test.csv, (and val.csv if available)" + } + ) + create_folds: bool = field( + default=False, + metadata={ + "help": "Whether or not we want to create folds for " + "K fold evaluation of the model" + }, + ) + + num_folds: int = field( + default=5, + metadata={ + "help": "The number of folds for K fold " + "evaluation of the model. Will not be used if create_folds is False" + }, + ) + validation_ratio: float = field( + default=0.2, + metadata={ + "help": "The ratio of dataset examples to be used for validation across" + "all folds for K fold evaluation. If num_folds is 5 and " + "validation_ratio is 0.2. Then a consistent 20% of the examples will" + "be used for validation for all folds. Then the remaining 80% is used" + "for K fold split for test and train sets so 0.2*0.8=16% of " + "all examples is used for testing and 0.8*0.8=64% of all examples" + "is used for training for each fold" + }, + ) + num_classes: int = field( + default=-1, metadata={"help": "Number of labels for classification if any"} + ) column_info_path: str = field( default=None, metadata={ - 'help': 'the path to the json file detailing which columns are text, categorical, numerical, and the label' - }) + "help": "the path to the json file detailing which columns are text, categorical, numerical, and the label" + }, + ) column_info: dict = field( default=None, metadata={ - 'help': 'a dict referencing the text, categorical, numerical, and label columns' - 'its keys are text_cols, num_cols, cat_cols, and label_col' - }) - - categorical_encode_type: str = field(default='ohe', - metadata={ - 'help': 'sklearn encoder to use for categorical data', - 'choices': ['ohe', 'binary', 'label', 'none'] - }) - numerical_transformer_method: str = field(default='yeo_johnson', - metadata={ - 'help': 'sklearn numerical transformer to preprocess numerical data', - 'choices': ['yeo_johnson', 'box_cox', 'quantile_normal', 'none'] - }) - task: str = field(default="classification", - metadata={ - "help": "The downstream training task", - "choices": ["classification", "regression"] - }) - - mlp_division: int = field(default=4, - metadata={ - 'help': 'the ratio of the number of ' - 'hidden dims in a current layer to the next MLP layer' - }) - combine_feat_method: str = field(default='individual_mlps_on_cat_and_numerical_feats_then_concat', - metadata={ - 'help': 'method to combine categorical and numerical features, ' - 'see README for all the method' - }) - mlp_dropout: float = field(default=0.1, - metadata={ - 'help': 'dropout ratio used for MLP layers' - }) - numerical_bn: bool = field(default=True, - metadata={ - 'help': 'whether to use batchnorm on numerical features' - }) - categorical_bn: bool = field(default=True, - metadata={ - 'help': 'whether to use batchnorm on categorical features' - }) - use_simple_classifier: str = field(default=True, - metadata={ - 'help': 'whether to use single layer or MLP as final classifier' - }) - mlp_act: str = field(default='relu', - metadata={ - 'help': 'the activation function to use for finetuning layers', - 'choices': ['relu', 'prelu', 'sigmoid', 'tanh', 'linear'] - }) - gating_beta: float = field(default=0.2, - metadata={ - 'help': "the beta hyperparameters used for gating tabular data " - "see https://www.aclweb.org/anthology/2020.acl-main.214.pdf" - }) + "help": "a dict referencing the text, categorical, numerical, and label columns" + "its keys are text_cols, num_cols, cat_cols, and label_col" + }, + ) + + categorical_encode_type: str 
= field( + default="ohe", + metadata={ + "help": "sklearn encoder to use for categorical data", + "choices": ["ohe", "binary", "label", "none"], + }, + ) + numerical_transformer_method: str = field( + default="yeo_johnson", + metadata={ + "help": "sklearn numerical transformer to preprocess numerical data", + "choices": ["yeo_johnson", "box_cox", "quantile_normal", "none"], + }, + ) + task: str = field( + default="classification", + metadata={ + "help": "The downstream training task", + "choices": ["classification", "regression"], + }, + ) + + mlp_division: int = field( + default=4, + metadata={ + "help": "the ratio of the number of " + "hidden dims in a current layer to the next MLP layer" + }, + ) + combine_feat_method: str = field( + default="individual_mlps_on_cat_and_numerical_feats_then_concat", + metadata={ + "help": "method to combine categorical and numerical features, " + "see README for all the method" + }, + ) + mlp_dropout: float = field( + default=0.1, metadata={"help": "dropout ratio used for MLP layers"} + ) + numerical_bn: bool = field( + default=True, + metadata={"help": "whether to use batchnorm on numerical features"}, + ) + categorical_bn: bool = field( + default=True, + metadata={"help": "whether to use batchnorm on categorical features"}, + ) + use_simple_classifier: str = field( + default=True, + metadata={"help": "whether to use single layer or MLP as final classifier"}, + ) + mlp_act: str = field( + default="relu", + metadata={ + "help": "the activation function to use for finetuning layers", + "choices": ["relu", "prelu", "sigmoid", "tanh", "linear"], + }, + ) + gating_beta: float = field( + default=0.2, + metadata={ + "help": "the beta hyperparameters used for gating tabular data " + "see https://www.aclweb.org/anthology/2020.acl-main.214.pdf" + }, + ) def __post_init__(self): - assert self.column_info != self.column_info_path, 'provide either a path to column_info or a dictionary' - assert 0 <= self.validation_ratio <= 1, 'validation_ratio must be between 0 and 1' + assert ( + self.column_info != self.column_info_path + ), "provide either a path to column_info or a dictionary" + assert ( + 0 <= self.validation_ratio <= 1 + ), "validation_ratio must be between 0 and 1" if self.column_info is None and self.column_info_path: - with open(self.column_info_path, 'r') as f: + with open(self.column_info_path, "r") as f: self.column_info = json.load(f) - assert 'text_cols' in self.column_info and 'label_col' in self.column_info - if 'cat_cols' not in self.column_info: - self.column_info['cat_cols'] = None - self.categorical_encode_type = 'none' - if 'num_cols' not in self.column_info: - self.column_info['num_cols'] = None - self.numerical_transformer_method = 'none' - if 'text_col_sep_token' not in self.column_info: - self.column_info['text_col_sep_token'] = None + assert "text_cols" in self.column_info and "label_col" in self.column_info + if "cat_cols" not in self.column_info: + self.column_info["cat_cols"] = None + self.categorical_encode_type = "none" + if "num_cols" not in self.column_info: + self.column_info["num_cols"] = None + self.numerical_transformer_method = "none" + if "text_col_sep_token" not in self.column_info: + self.column_info["text_col_sep_token"] = None + @dataclass class OurTrainingArguments(TrainingArguments): experiment_name: Optional[str] = field( - default=None, - metadata={'help': 'A name for the experiment'} + default=None, metadata={"help": "A name for the experiment"} ) debug_dataset: bool = field( default=False, - metadata={'help': 
'Whether we are training in debug mode (smaller model)'} + metadata={"help": "Whether we are training in debug mode (smaller model)"}, ) debug_dataset_size: int = field( default=100, - metadata={'help': 'Size of the dataset in debug mode. Only used when debug_dataset = True.'} + metadata={ + "help": "Size of the dataset in debug mode. Only used when debug_dataset = True." + }, ) - do_eval: bool = field(default=True, metadata={"help": "Whether to run eval on the dev set."}) - do_predict: bool = field(default=True, metadata={"help": "Whether to run predictions on the test set."}) + do_eval: bool = field( + default=True, metadata={"help": "Whether to run eval on the dev set."} + ) + do_predict: bool = field( + default=True, metadata={"help": "Whether to run predictions on the test set."} + ) max_token_length: Optional[int] = field( - default=None, - metadata={'help': 'The maximum token length'} + default=None, metadata={"help": "The maximum token length"} ) gradient_accumulation_steps: int = field( default=1, - metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + metadata={ + "help": "Number of updates steps to accumulate before performing a backward/update pass." + }, ) - learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) + learning_rate: float = field( + default=5e-5, metadata={"help": "The initial learning rate for Adam."} + ) def __post_init__(self): super().__post_init__() if self.debug_dataset: self.max_token_length = 16 self.logging_steps = 5 - self.overwrite_output_dir = True \ No newline at end of file + self.overwrite_output_dir = True From 5b0bffe06541475eacfa163b84d1b5cef77a760e Mon Sep 17 00:00:00 2001 From: Akash Saravanan Date: Tue, 7 Nov 2023 13:47:37 -0700 Subject: [PATCH 12/12] Fix(setup.py): update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f31165e..c8128b1 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ def get_version(rel_path: str) -> str: install_requires=install_requires, python_requires=">=3.7", classifiers=[ - "Development Status :: 5 - Production/Stable", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package + "Development Status :: 4 - Beta", # Chose either "3 - Alpha", "4 - Beta" or "5 - Production/Stable" as the current state of your package "Intended Audience :: Developers", # Define that your audience are developers "Topic :: Software Development :: Build Tools", "License :: OSI Approved :: MIT License", # Again, pick a license