diff --git a/multimodal_transformers/data/load_data.py b/multimodal_transformers/data/load_data.py
index 30b8ef3..c621e8d 100644
--- a/multimodal_transformers/data/load_data.py
+++ b/multimodal_transformers/data/load_data.py
@@ -380,7 +380,9 @@ def load_train_val_test_helper(
             join(output_dir, "numerical_transformer.pkl"),
         )
     if categorical_transformer:
-        joblib.dump(categorical_transformer, join(output_dir, "categorical_transformer.pkl"))
+        joblib.dump(
+            categorical_transformer, join(output_dir, "categorical_transformer.pkl")
+        )
     torch.save(train_dataset, join(output_dir, "train_data.pt"))
     torch.save(test_dataset, join(output_dir, "test_data.pt"))
     if val_dataset:
@@ -389,6 +391,46 @@ def load_train_val_test_helper(
     return train_dataset, val_dataset, test_dataset
 
 
+def build_categorical_features(data_df, categorical_cols, categorical_transformer):
+    if len(categorical_cols) > 0:
+        # Find columns in the dataset that are in categorical_cols
+        categorical_cols_func = convert_to_func(categorical_cols)
+        categorical_cols = get_matching_cols(data_df, categorical_cols_func)
+        if categorical_transformer is not None:
+            return categorical_transformer.transform(data_df[categorical_cols])
+        else:
+            return data_df[categorical_cols]
+    else:
+        return None
+
+
+def build_numerical_features(data_df, numerical_cols, numerical_transformer):
+    if len(numerical_cols) > 0:
+        # Find columns in the dataset that are in numerical_cols
+        numerical_cols_func = convert_to_func(numerical_cols)
+        numerical_cols = get_matching_cols(data_df, numerical_cols_func)
+        if numerical_transformer is not None:
+            return numerical_transformer.transform(data_df[numerical_cols])
+        else:
+            return data_df[numerical_cols]
+    else:
+        return None
+
+
+def build_text_features(
+    data_df, text_cols, empty_text_values, replace_empty_text, sep_text_token_str
+):
+    text_cols_func = convert_to_func(text_cols)
+    agg_func = partial(agg_text_columns_func, empty_text_values, replace_empty_text)
+    text_cols = get_matching_cols(data_df, text_cols_func)
+    logger.info(f"Text columns: {text_cols}")
+    texts_list = data_df[text_cols].agg(agg_func, axis=1).tolist()
+    for i, text in enumerate(texts_list):
+        texts_list[i] = f" {sep_text_token_str} ".join(text)
+    logger.info(f"Raw text example: {texts_list[0]}")
+    return texts_list
+
+
 def load_data(
     data_df,
     text_cols,
@@ -453,44 +495,25 @@ def load_data(
     if empty_text_values is None:
         empty_text_values = ["nan", "None"]
 
-    text_cols_func = convert_to_func(text_cols)
-    categorical_cols_func = convert_to_func(categorical_cols)
-    numerical_cols_func = convert_to_func(numerical_cols)
-
     # Build categorical features
-    if len(categorical_cols) > 0:
-        # Find columns in the dataset that are in categorical_cols
-        categorical_cols = get_matching_cols(data_df, categorical_cols_func)
-        if categorical_transformer is not None:
-            categorical_feats = categorical_transformer.transform(
-                data_df[categorical_cols]
-            )
-        else:
-            categorical_feats = data_df[categorical_cols]
-    else:
-        categorical_feats = None
-
+    categorical_feats = build_categorical_features(
+        data_df=data_df,
+        categorical_cols=categorical_cols,
+        categorical_transformer=categorical_transformer,
+    )
     # Build numerical features
-    if len(numerical_cols) > 0:
-        # Find columns in the dataset that are in numerical_cols
-        numerical_cols = get_matching_cols(data_df, numerical_cols_func)
-        if numerical_transformer is not None:
-            numerical_feats = numerical_transformer.transform(data_df[numerical_cols])
-        else:
-            numerical_feats = data_df[numerical_cols]
-    else:
-        numerical_feats = None
+    numerical_feats = build_numerical_features(
+        data_df=data_df,
+        numerical_cols=numerical_cols,
+        numerical_transformer=numerical_transformer,
+    )
 
     # Build text features
-    agg_func = partial(agg_text_columns_func, empty_text_values, replace_empty_text)
-    texts_cols = get_matching_cols(data_df, text_cols_func)
-    logger.info(f"Text columns: {texts_cols}")
-    texts_list = data_df[texts_cols].agg(agg_func, axis=1).tolist()
-    for i, text in enumerate(texts_list):
-        texts_list[i] = f" {sep_text_token_str} ".join(text)
-    logger.info(f"Raw text example: {texts_list[0]}")
+    texts_list = build_text_features(
+        data_df, text_cols, empty_text_values, replace_empty_text, sep_text_token_str
+    )
 
-    # Create tokenizer
+    # Create tokenized text features
     hf_model_text_input = tokenizer(
         texts_list, padding=True, truncation=True, max_length=max_token_length
    )
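
For reviewers, a minimal usage sketch of the three extracted helpers on a toy DataFrame. The toy column names, the `StandardScaler`, and the `[SEP]` separator are illustrative assumptions and are not taken from this diff; the helpers themselves are imported from the module where the diff defines them.

```python
# Hypothetical usage sketch (not part of the diff). Assumes the build_* helpers
# are importable from multimodal_transformers.data.load_data as defined above.
import pandas as pd
from sklearn.preprocessing import StandardScaler

from multimodal_transformers.data.load_data import (
    build_categorical_features,
    build_numerical_features,
    build_text_features,
)

# Toy data; column names and values are illustrative only.
df = pd.DataFrame(
    {
        "review_title": ["Great fit", "nan"],  # "nan" is treated as empty text
        "review_body": ["Runs true to size.", "Would not buy again."],
        "color": ["red", "blue"],
        "price": [19.99, 35.50],
    }
)

# With no transformer, the helper returns the raw selected columns.
cat_feats = build_categorical_features(df, ["color"], categorical_transformer=None)

# With a fitted transformer, the helper returns the transformed values.
num_transformer = StandardScaler().fit(df[["price"]])
num_feats = build_numerical_features(df, ["price"], num_transformer)

texts = build_text_features(
    df,
    text_cols=["review_title", "review_body"],
    empty_text_values=["nan", "None"],
    replace_empty_text=None,
    sep_text_token_str="[SEP]",
)
# Expected first entry: "Great fit [SEP] Runs true to size."
```

This exercises both branches of the refactored feature builders (transformer present vs. `None`) and the text aggregation path that `load_data` now delegates to.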