-
Notifications
You must be signed in to change notification settings - Fork 85
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #79 from georgian-io/akash/better-preprocessing
Feat: Better Preprocessing
- Loading branch information
Showing
16 changed files
with
719 additions
and
287 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
""" | ||
This is an example of how to use the toolkit to run inference. | ||
We use saved feature processors (generated by code in the tests folder). | ||
This code assumes you have a dataframe of datapoints and batches them up. | ||
""" | ||
|
||
import os | ||
import sys | ||
|
||
sys.path.append("./") | ||
|
||
import joblib | ||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, set_seed | ||
|
||
from multimodal_exp_args import ( | ||
ModelArguments, | ||
MultimodalDataTrainingArguments, | ||
OurTrainingArguments, | ||
) | ||
from multimodal_transformers.data import load_data | ||
from multimodal_transformers.model import AutoModelWithTabular, TabularConfig | ||
|
||
# Example configuration: edit these to point at your own run artifacts.
DEBUG = True
DEBUG_DATASET_SIZE = 50
JSON_FILE = "./tests/test_airbnb.json"
MODEL_SAVE_DIR = "./logs_airbnb/bertmultilingual_gating_on_cat_and_num_feats_then_sum_full_model_lr_3e-3/"
NUMERICAL_TRANSFORMER_PATH = os.path.join(
    MODEL_SAVE_DIR, "numerical_transformer.pkl"
)
CATEGORICAL_TRANSFORMER_PATH = os.path.join(
    MODEL_SAVE_DIR, "categorical_transformer.pkl"
)
MODEL_CONFIG_PATH = os.path.join(MODEL_SAVE_DIR, "config.json")
MODEL_PATH = os.path.join(MODEL_SAVE_DIR, "model.safetensors")


def main() -> None:
    """Run batched inference on ``test.csv`` with a saved tabular+text model.

    Steps, in order:

    1. Parse the experiment JSON (``JSON_FILE``) into model, data and
       training argument objects.
    2. Build the tokenizer and load the fitted categorical / numerical
       feature transformers saved next to the model.
    3. Read ``test.csv`` from ``data_args.data_path`` and preprocess it
       with ``load_data`` using those saved transformers.
    4. Rebuild the model config (including the tabular config) and load
       the weights from ``MODEL_PATH``.
    5. Run the model over the dataset in batches of 16 without gradients
       and print a single metric: accuracy for classification tasks,
       mean squared error for regression tasks.
    """
    # Parse our input json files
    parser = HfArgumentParser(
        (ModelArguments, MultimodalDataTrainingArguments, OurTrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_json_file(
        json_file=os.path.abspath(JSON_FILE)
    )

    # Set random seed for reproducibility
    set_seed(training_args.seed)

    # Create a tokenizer; fall back to the model name when no dedicated
    # tokenizer name is configured.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )

    # Load our feature processors.
    # NOTE(review): joblib.load unpickles arbitrary objects; only point
    # these paths at files you produced yourself or otherwise trust.
    categorical_transformer = joblib.load(CATEGORICAL_TRANSFORMER_PATH)
    numerical_transformer = joblib.load(NUMERICAL_TRANSFORMER_PATH)

    # Load our test set
    data_df = pd.read_csv(os.path.join(data_args.data_path, "test.csv"))

    # Load and preprocess our test dataset
    column_info = data_args.column_info
    test_dataset = load_data(
        data_df=data_df,
        text_cols=column_info["text_cols"],
        tokenizer=tokenizer,
        label_col=column_info["label_col"],
        label_list=column_info["label_list"],
        categorical_cols=column_info["cat_cols"],
        numerical_cols=column_info["num_cols"],
        # Use the configured separator if there is one, else the tokenizer's.
        sep_text_token_str=(
            column_info["text_col_sep_token"] or tokenizer.sep_token
        ),
        categorical_transformer=categorical_transformer,
        numerical_transformer=numerical_transformer,
        max_token_length=training_args.max_token_length,
        debug=DEBUG,
        debug_dataset_size=DEBUG_DATASET_SIZE,
    )

    is_regression = data_args.task == "regression"
    if is_regression:
        # Regression tasks have only one "label"
        num_labels = 1
    elif data_args.num_classes == -1:
        # NOTE(review): the class count is inferred from the labels present
        # in the loaded test data. With DEBUG on (or a small test set) this
        # may presumably be fewer classes than the model was trained with;
        # set num_classes explicitly in the JSON to avoid a mismatch.
        num_labels = len(np.unique(test_dataset.labels))
    else:
        num_labels = data_args.num_classes

    # Setup configs
    config = AutoConfig.from_pretrained(
        MODEL_CONFIG_PATH,
        cache_dir=model_args.cache_dir,
    )
    tabular_config = TabularConfig(
        num_labels=num_labels,
        cat_feat_dim=(
            test_dataset.cat_feats.shape[1]
            if test_dataset.cat_feats is not None
            else 0
        ),
        numerical_feat_dim=(
            test_dataset.numerical_feats.shape[1]
            if test_dataset.numerical_feats is not None
            else 0
        ),
        **vars(data_args),
    )
    config.tabular_config = tabular_config

    # Make model
    model = AutoModelWithTabular.from_pretrained(
        MODEL_PATH,
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Run inference
    dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16)
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch in dataloader:
            # The model returns a 3-tuple; only the logits are needed here.
            _, logits, _ = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"],
                cat_feats=batch["cat_feats"],
                numerical_feats=batch["numerical_feats"],
            )
            all_labels.append(batch["labels"])
            if is_regression:
                # With a single output column argmax would always be 0, so
                # keep the raw predicted values instead.
                # Assumes one logit per example -- TODO confirm.
                all_preds.append(logits.reshape(-1))
            else:
                all_preds.append(logits.argmax(dim=1))

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    if is_regression:
        # Accuracy is not meaningful for continuous targets; report MSE.
        mse = torch.mean((all_preds - all_labels.reshape(-1).float()) ** 2)
        print(f"MSE: {mse}")
    else:
        acc = torch.sum(all_preds == all_labels) / all_labels.shape[0]
        print(f"Accuracy: {acc}")


if __name__ == "__main__":
    main()
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.