diff --git a/examples/inference.py b/examples/inference.py
new file mode 100644
index 0000000..571c503
--- /dev/null
+++ b/examples/inference.py
@@ -0,0 +1,164 @@
+"""
+An example of running inference with the toolkit. It uses saved feature
+processors (generated by code in the tests folder) and assumes the
+datapoints live in a dataframe, which is batched up for the model.
+"""
+
+import os
+import sys
+
+# Make the repo's local modules importable when run from the repo root
+sys.path.append("./")
+
+import joblib
+import numpy as np
+import pandas as pd
+import torch
+from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, set_seed
+
+from multimodal_exp_args import (
+    ModelArguments,
+    MultimodalDataTrainingArguments,
+    OurTrainingArguments,
+)
+from multimodal_transformers.data import load_data
+from multimodal_transformers.model import AutoModelWithTabular, TabularConfig
+
+if __name__ == "__main__":
+    DEBUG = True
+    DEBUG_DATASET_SIZE = 50
+    JSON_FILE = "./tests/test_airbnb.json"
+    MODEL_SAVE_DIR = "./logs_airbnb/bertmultilingual_gating_on_cat_and_num_feats_then_sum_full_model_lr_3e-3/"
+    NUMERICAL_TRANSFORMER_PATH = os.path.join(
+        MODEL_SAVE_DIR, "numerical_transformer.pkl"
+    )
+    CATEGORICAL_TRANSFORMER_PATH = os.path.join(
+        MODEL_SAVE_DIR, "categorical_transformer.pkl"
+    )
+    MODEL_CONFIG_PATH = os.path.join(MODEL_SAVE_DIR, "config.json")
+    MODEL_PATH = os.path.join(MODEL_SAVE_DIR, "model.safetensors")
+
+    # Parse our input JSON file
+    parser = HfArgumentParser(
+        (ModelArguments, MultimodalDataTrainingArguments, OurTrainingArguments)
+    )
+    model_args, data_args, training_args = parser.parse_json_file(
+        json_file=os.path.abspath(JSON_FILE)
+    )
+
+    # Set random seed for reproducibility
+    set_seed(training_args.seed)
+
+    # Create a tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        (
+            model_args.tokenizer_name
+            if model_args.tokenizer_name
+            else model_args.model_name_or_path
+        ),
+        cache_dir=model_args.cache_dir,
+    )
+
+    # Load our saved feature processors
+    categorical_transformer = joblib.load(CATEGORICAL_TRANSFORMER_PATH)
+    numerical_transformer = joblib.load(NUMERICAL_TRANSFORMER_PATH)
+
+    # Load our test set
+    data_df = pd.read_csv(os.path.join(data_args.data_path, "test.csv"))
+
+    # Tokenize and preprocess our test dataset
+    test_dataset = load_data(
+        data_df=data_df,
+        text_cols=data_args.column_info["text_cols"],
+        tokenizer=tokenizer,
+        label_col=data_args.column_info["label_col"],
+        label_list=data_args.column_info["label_list"],
+        categorical_cols=data_args.column_info["cat_cols"],
+        numerical_cols=data_args.column_info["num_cols"],
+        sep_text_token_str=(
+            tokenizer.sep_token
+            if not data_args.column_info["text_col_sep_token"]
+            else data_args.column_info["text_col_sep_token"]
+        ),
+        categorical_transformer=categorical_transformer,
+        numerical_transformer=numerical_transformer,
+        max_token_length=training_args.max_token_length,
+        debug=DEBUG,
+        debug_dataset_size=DEBUG_DATASET_SIZE,
+    )
+
+    task = data_args.task
+    # Regression tasks have only one "label"
+    if task == "regression":
+        num_labels = 1
+    else:
+        num_labels = (
+            len(np.unique(test_dataset.labels))
+            if data_args.num_classes == -1
+            else data_args.num_classes
+        )
+
+    # Set up the configs
+    config = AutoConfig.from_pretrained(
+        MODEL_CONFIG_PATH,
+        cache_dir=model_args.cache_dir,
+    )
+    tabular_config = TabularConfig(
+        num_labels=num_labels,
+        cat_feat_dim=(
+            test_dataset.cat_feats.shape[1] if test_dataset.cat_feats is not None else 0
+        ),
+        numerical_feat_dim=(
+            test_dataset.numerical_feats.shape[1]
+            if test_dataset.numerical_feats is not None
+            else 0
+        ),
+        # The remaining tabular settings are forwarded from the data
+        # arguments parsed above
+        **vars(data_args),
+    )
+    config.tabular_config = tabular_config
+
+    # Make model
+    model = AutoModelWithTabular.from_pretrained(
+        MODEL_PATH,
+        config=config,
+        cache_dir=model_args.cache_dir,
+    )
+
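+    # Optional sketch (an assumption, not part of the original example): to
+    # run on a GPU when one is available, the model and each batch could be
+    # moved to the same device, e.g.
+    #   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    #   model.to(device)
+    # and, inside the loop below,
+    #   batch = {k: v.to(device) for k, v in batch.items()}
+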
+    # Run inference
+    dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16)
+    model.eval()
+    all_labels = []
+    all_preds = []
+    with torch.no_grad():
+        for batch in dataloader:
+            # The forward pass returns (loss, logits, classifier layer outputs)
+            _, logits, classifier_outputs = model(
+                batch["input_ids"],
+                attention_mask=batch["attention_mask"],
+                token_type_ids=batch["token_type_ids"],
+                cat_feats=batch["cat_feats"],
+                numerical_feats=batch["numerical_feats"],
+            )
+            all_labels.append(batch["labels"])
+            all_preds.append(logits.argmax(dim=1))
+
+    all_preds = torch.cat(all_preds)
+    all_labels = torch.cat(all_labels)
+    acc = (all_preds == all_labels).float().mean().item()
+    print(f"Accuracy: {acc:.4f}")
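+
+    # Note: the accuracy above assumes a classification task. For
+    # task == "regression" the logits are the raw predictions, so a metric
+    # such as mean absolute error would be the analogue (a sketch):
+    #   mae = torch.mean(torch.abs(all_preds - all_labels)).item()
+    # with all_preds collected from the raw logits rather than their argmax.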