Skip to content

Commit

Permalink
Actually save the transformer weights
Browse files Browse the repository at this point in the history
  • Loading branch information
woodRock committed Sep 24, 2024
1 parent d75e9d3 commit 5d766d4
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 222 deletions.
Binary file modified code/transformer/figures/decoder_attention_map.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified code/transformer/figures/encoder_attention_map.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified code/transformer/figures/model_accuracy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified code/transformer/figures/train_confusion_matrix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified code/transformer/figures/validation_confusion_matrix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
444 changes: 225 additions & 219 deletions code/transformer/logs/results_0.log

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion code/transformer/multi-task.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def main():
num_epochs=args.epochs,
patience=args.early_stopping
)

# Save the model to disk.
torch.save(model.state_dict(), args.file_path)

# finish measuring how long training took
endTime = time.time()
Expand All @@ -141,4 +144,4 @@ def main():


if __name__ == "__main__":
main()
main()
5 changes: 3 additions & 2 deletions code/transformer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def random_augmentation(
return xs, ys

def load_from_file(
path: Iterable = ["~/", "Desktop", "fishy-business", "data", "REIMS_data.xlsx"]
# path: Iterable = ["~/", "Desktop", "fishy-business", "data", "REIMS_data.xlsx"]
path: Iterable = ["/vol","ecrg-solar","woodj4","fishy-business","data", "REIMS_data.xlsx"]
) -> pd.DataFrame:
""" Load the dataset from a file path.
Expand Down Expand Up @@ -351,4 +352,4 @@ def preprocess_dataset(
is_data_augmentation=is_data_augmentation,
batch_size=batch_size
)
return train_loader, val_loader, train_steps, val_steps, data
return train_loader, val_loader, train_steps, val_steps, data

0 comments on commit 5d766d4

Please sign in to comment.