feat: add inference example
akashsara committed Sep 6, 2024
1 parent f4fc678 commit 58e722b
Showing 1 changed file with 146 additions and 0 deletions.
examples/inference.py: 146 additions, 0 deletions
@@ -0,0 +1,146 @@
"""
This is an example of how to use the toolkit to run inference.
We use saved feature processors (generated by code in the tests folder).
This code assumes you have a dataframe of datapoints and batches them up.
"""

import os
import sys

# Make the repository root importable so the example runs without
# installing the package.
sys.path.append("./")

import joblib
import numpy as np
import pandas as pd
import torch
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, set_seed

from multimodal_exp_args import (
ModelArguments,
MultimodalDataTrainingArguments,
OurTrainingArguments,
)
from multimodal_transformers.data import load_data
from multimodal_transformers.model import AutoModelWithTabular, TabularConfig

if __name__ == "__main__":
DEBUG = True
DEBUG_DATASET_SIZE = 50
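    # Artifacts saved by a prior training run (the feature processors come from
    # the tests folder); point MODEL_SAVE_DIR at your own output directory.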
JSON_FILE = "./tests/test_airbnb.json"
MODEL_SAVE_DIR = "./logs_airbnb/bertmultilingual_gating_on_cat_and_num_feats_then_sum_full_model_lr_3e-3/"
NUMERICAL_TRANSFORMER_PATH = os.path.join(
MODEL_SAVE_DIR, "numerical_transformer.pkl"
)
CATEGORICAL_TRANSFORMER_PATH = os.path.join(
MODEL_SAVE_DIR, "categorical_transformer.pkl"
)
MODEL_CONFIG_PATH = os.path.join(MODEL_SAVE_DIR, "config.json")
MODEL_PATH = os.path.join(MODEL_SAVE_DIR, "model.safetensors")

# Parse our input json files
parser = HfArgumentParser(
(ModelArguments, MultimodalDataTrainingArguments, OurTrainingArguments)
)
model_args, data_args, training_args = parser.parse_json_file(
json_file=os.path.abspath(JSON_FILE)
)

# Set random seed for reproducibility
set_seed(training_args.seed)

# Create a tokenizer
tokenizer = AutoTokenizer.from_pretrained(
(
model_args.tokenizer_name
if model_args.tokenizer_name
else model_args.model_name_or_path
),
cache_dir=model_args.cache_dir,
)

# Load our feature processors
categorical_transformer = joblib.load(CATEGORICAL_TRANSFORMER_PATH)
numerical_transformer = joblib.load(NUMERICAL_TRANSFORMER_PATH)

# Load our test set
data_df = pd.read_csv(os.path.join(data_args.data_path, "test.csv"))

# Load and preprocess our test dataset
test_dataset = load_data(
data_df=data_df,
text_cols=data_args.column_info["text_cols"],
tokenizer=tokenizer,
label_col=data_args.column_info["label_col"],
label_list=data_args.column_info["label_list"],
categorical_cols=data_args.column_info["cat_cols"],
numerical_cols=data_args.column_info["num_cols"],
sep_text_token_str=(
tokenizer.sep_token
if not data_args.column_info["text_col_sep_token"]
else data_args.column_info["text_col_sep_token"]
),
categorical_transformer=categorical_transformer,
numerical_transformer=numerical_transformer,
max_token_length=training_args.max_token_length,
debug=DEBUG,
debug_dataset_size=DEBUG_DATASET_SIZE,
)
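    # The same load_data call works for ad-hoc inference: wrap raw datapoints
    # in a dataframe whose columns match column_info, e.g. (hypothetical
    # column names):
    #   pd.DataFrame([{"description": "Cozy loft near downtown", "price": 100.0}])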

task = data_args.task
# Regression tasks have only one "label"
if task == "regression":
num_labels = 1
else:
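        # num_classes == -1 means "infer the number of labels from the data".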
num_labels = (
len(np.unique(test_dataset.labels))
if data_args.num_classes == -1
else data_args.num_classes
)

# Setup configs
config = AutoConfig.from_pretrained(
MODEL_CONFIG_PATH,
cache_dir=model_args.cache_dir,
)
tabular_config = TabularConfig(
num_labels=num_labels,
cat_feat_dim=(
test_dataset.cat_feats.shape[1] if test_dataset.cat_feats is not None else 0
),
numerical_feat_dim=(
test_dataset.numerical_feats.shape[1]
if test_dataset.numerical_feats is not None
else 0
),
**vars(data_args),
)
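    # Attach the tabular config to the text model's config so that
    # from_pretrained below can build the combined text + tabular model.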
config.tabular_config = tabular_config

# Make model
model = AutoModelWithTabular.from_pretrained(
MODEL_PATH,
config=config,
cache_dir=model_args.cache_dir,
)
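    # Optional (a sketch): if a GPU is available you could move the model, and
    # each batch in the loop below, onto it; this example stays on CPU.
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    # model = model.to(device)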

# Run inference
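    # The dataset yields dict-of-tensor items, so the default collate_fn can
    # batch them directly.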
dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16)
model.eval()
all_labels = []
all_preds = []
with torch.no_grad():
for batch in dataloader:
_, logits, classifier_outputs = model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
token_type_ids=batch["token_type_ids"],
cat_feats=batch["cat_feats"],
numerical_feats=batch["numerical_feats"],
)
all_labels.append(batch["labels"])
            # argmax assumes a classification task; for a regression task you
            # would keep the raw logits instead.
            all_preds.append(logits.argmax(dim=1))

all_preds = torch.cat(all_preds)
all_labels = torch.cat(all_labels)
    acc = (all_preds == all_labels).float().mean().item()
    print(f"Accuracy: {acc}")
