Skip to content

Commit

Permalink
VAE weighted cross entropy
Browse files Browse the repository at this point in the history
  • Loading branch information
woodRock committed Oct 3, 2024
1 parent edde103 commit c27681e
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 167 deletions.
Binary file modified code/vae/__pycache__/train.cpython-310.pyc
Binary file not shown.
Binary file modified code/vae/__pycache__/util.cpython-310.pyc
Binary file not shown.
Binary file modified code/vae/__pycache__/vae.cpython-310.pyc
Binary file not shown.
Binary file modified code/vae/figures/model_accuracy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified code/vae/figures/train_confusion_matrix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified code/vae/figures/validation_confusion_matrix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
187 changes: 21 additions & 166 deletions code/vae/logs/results_0.log

Large diffs are not rendered by default.

37 changes: 36 additions & 1 deletion code/vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from util import preprocess_dataset
from train import train_model, evaluate_model
from vae import VAE
Expand Down Expand Up @@ -66,6 +67,35 @@ def setup_logging(args): # Logging output to a file.
logging.basicConfig(filename=output, level=logging.INFO, filemode='w')
return logger


def calculate_class_weights(train_loader: DataLoader) -> torch.Tensor:
    """
    Calculate inverse-frequency weights for each class in the dataset.

    Each observed class gets a weight proportional to ``1 / count``; the
    weights are then rescaled so that they sum to the number of observed
    classes (i.e. observed classes have a mean weight of 1).

    Classes that never occur in the loader receive a weight of 0 instead of
    raising ``KeyError``, so the returned tensor is always indexable by
    class id.

    Args:
        train_loader (DataLoader): Iterable yielding ``(inputs, labels)``
            batches. Labels with two or more dimensions are treated as
            one-hot / score vectors and reduced with ``argmax`` per sample;
            1-D labels are treated as integer class indices.
    Returns:
        torch.Tensor: A float32 CPU tensor with one weight per class. It is
            empty when the loader yields no labels.
    """
    counts = torch.zeros(0, dtype=torch.long)

    for _, labels in train_loader:
        labels = torch.as_tensor(labels)
        if labels.numel() == 0:
            continue

        if labels.dim() < 2:
            # Labels are already class indices: argmax of a scalar would
            # always be 0, so use the values directly.
            class_ids = labels.reshape(-1).long()
            width = int(class_ids.max().item()) + 1
        else:
            # One row per sample; argmax over the flattened row matches the
            # per-sample `label.argmax()` semantics.
            per_sample = labels.reshape(labels.shape[0], -1)
            class_ids = per_sample.argmax(dim=1)
            width = per_sample.shape[1]

        batch_counts = torch.bincount(
            class_ids.cpu(), minlength=max(width, counts.numel())
        )
        # bincount returns at least `counts.numel()` bins, so only the
        # running total may need zero-padding before accumulation.
        missing = batch_counts.numel() - counts.numel()
        if missing > 0:
            counts = torch.cat([counts, counts.new_zeros(missing)])
        counts = counts + batch_counts

    observed = counts > 0
    class_weights = torch.zeros(counts.numel(), dtype=torch.float32)
    class_weights[observed] = 1.0 / counts[observed].float()

    if observed.any():
        # Normalise so the weights of the observed classes average to 1.
        class_weights = class_weights / class_weights.sum() * observed.sum()

    return class_weights

def main():
args = parse_arguments()
logger = setup_logging(args)
Expand All @@ -87,6 +117,9 @@ def main():
is_pre_train=False
)

# Calculate class weights
class_weights = calculate_class_weights(train_loader)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Expand All @@ -98,8 +131,10 @@ def main():
dropout=args.dropout
)
model = model.to(device)
class_weights = class_weights.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
# Initialize the weighted loss function
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=args.label_smoothing)
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

model = train_model(
Expand Down

0 comments on commit c27681e

Please sign in to comment.