Commit

refactor: make functions for building text/num/cat features
akashsara committed Sep 6, 2024
1 parent c233acb commit f4fc678
Showing 1 changed file with 57 additions and 34 deletions.
91 changes: 57 additions & 34 deletions multimodal_transformers/data/load_data.py
@@ -380,7 +380,9 @@ def load_train_val_test_helper(
join(output_dir, "numerical_transformer.pkl"),
)
if categorical_transformer:
joblib.dump(categorical_transformer, join(output_dir, "categorical_transformer.pkl"))
joblib.dump(
categorical_transformer, join(output_dir, "categorical_transformer.pkl")
)
torch.save(train_dataset, join(output_dir, "train_data.pt"))
torch.save(test_dataset, join(output_dir, "test_data.pt"))
if val_dataset:
@@ -389,6 +391,46 @@ def load_train_val_test_helper(
return train_dataset, val_dataset, test_dataset


def build_categorical_features(data_df, categorical_cols, categorical_transformer):
    if len(categorical_cols) > 0:
        # Find columns in the dataset that are in categorical_cols
        categorical_cols_func = convert_to_func(categorical_cols)
        categorical_cols = get_matching_cols(data_df, categorical_cols_func)
        if categorical_transformer is not None:
            return categorical_transformer.transform(data_df[categorical_cols])
        else:
            return data_df[categorical_cols]
    else:
        return None


def build_numerical_features(data_df, numerical_cols, numerical_transformer):
    if len(numerical_cols) > 0:
        # Find columns in the dataset that are in numerical_cols
        numerical_cols_func = convert_to_func(numerical_cols)
        numerical_cols = get_matching_cols(data_df, numerical_cols_func)
        if numerical_transformer is not None:
            return numerical_transformer.transform(data_df[numerical_cols])
        else:
            return data_df[numerical_cols]
    else:
        return None


def build_text_features(
    data_df, text_cols, empty_text_values, replace_empty_text, sep_text_token_str
):
    text_cols_func = convert_to_func(text_cols)
    agg_func = partial(agg_text_columns_func, empty_text_values, replace_empty_text)
    text_cols = get_matching_cols(data_df, text_cols_func)
    logger.info(f"Text columns: {text_cols}")
    texts_list = data_df[text_cols].agg(agg_func, axis=1).tolist()
    for i, text in enumerate(texts_list):
        texts_list[i] = f" {sep_text_token_str} ".join(text)
    logger.info(f"Raw text example: {texts_list[0]}")
    return texts_list


def load_data(
data_df,
text_cols,
@@ -453,44 +495,25 @@ def load_data(
    if empty_text_values is None:
        empty_text_values = ["nan", "None"]

    text_cols_func = convert_to_func(text_cols)
    categorical_cols_func = convert_to_func(categorical_cols)
    numerical_cols_func = convert_to_func(numerical_cols)

    # Build categorical features
    if len(categorical_cols) > 0:
        # Find columns in the dataset that are in categorical_cols
        categorical_cols = get_matching_cols(data_df, categorical_cols_func)
        if categorical_transformer is not None:
            categorical_feats = categorical_transformer.transform(
                data_df[categorical_cols]
            )
        else:
            categorical_feats = data_df[categorical_cols]
    else:
        categorical_feats = None

    categorical_feats = build_categorical_features(
        data_df=data_df,
        categorical_cols=categorical_cols,
        categorical_transformer=categorical_transformer,
    )
    # Build numerical features
    if len(numerical_cols) > 0:
        # Find columns in the dataset that are in numerical_cols
        numerical_cols = get_matching_cols(data_df, numerical_cols_func)
        if numerical_transformer is not None:
            numerical_feats = numerical_transformer.transform(data_df[numerical_cols])
        else:
            numerical_feats = data_df[numerical_cols]
    else:
        numerical_feats = None
    numerical_feats = build_numerical_features(
        data_df=data_df,
        numerical_cols=numerical_cols,
        numerical_transformer=numerical_transformer,
    )

    # Build text features
    agg_func = partial(agg_text_columns_func, empty_text_values, replace_empty_text)
    texts_cols = get_matching_cols(data_df, text_cols_func)
    logger.info(f"Text columns: {texts_cols}")
    texts_list = data_df[texts_cols].agg(agg_func, axis=1).tolist()
    for i, text in enumerate(texts_list):
        texts_list[i] = f" {sep_text_token_str} ".join(text)
    logger.info(f"Raw text example: {texts_list[0]}")
    texts_list = build_text_features(
        data_df, text_cols, empty_text_values, replace_empty_text, sep_text_token_str
    )

    # Create tokenizer
    # Create tokenized text features
    hf_model_text_input = tokenizer(
        texts_list, padding=True, truncation=True, max_length=max_token_length
    )
