This repository contains code for training and evaluating breed classification models using PyTorch. Two custom CNN models, one following a hierarchical (local) approach and one following a global approach, are trained and compared with a pre-trained VGG16 model. The code also includes functionality for generating heatmaps that visualize model activations.
Requirements:
- Python 3.x
- PyTorch
- torchvision
- pandas
- numpy
- matplotlib
- seaborn
Clone the repository:
git clone https://github.com/yourusername/breed-classifier.git
Loading Dataset: The Oxford-IIIT Pet dataset is available on Kaggle. It is loaded from the provided annotations file, whose entries are preprocessed to extract the necessary information.
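A minimal sketch of the annotation parsing, assuming the annotations file follows the `list.txt` layout of the original Oxford-IIIT Pet release: `#`-prefixed comment lines, then one whitespace-separated row per image with the image name, class id, species id (1 = cat, 2 = dog) and breed id. The Kaggle copy may be laid out differently, and the helper `load_annotations` and the derived column names are hypothetical, not taken from this repository.

```python
import io

import pandas as pd


def load_annotations(source) -> pd.DataFrame:
    """Parse an Oxford-IIIT Pet style annotations file (hypothetical helper).

    Assumed format: lines starting with '#' are comments; every other line is
    '<image> <class_id> <species> <breed_id>' separated by whitespace.
    """
    df = pd.read_csv(
        source,
        sep=r"\s+",
        comment="#",
        header=None,
        names=["image", "class_id", "species", "breed_id"],
    )
    # The breed name is the file stem without its trailing "_<index>",
    # e.g. "american_bulldog_10" -> "american_bulldog".
    df["breed"] = df["image"].str.rsplit("_", n=1).str[0]
    # Species ids: 1 = cat, 2 = dog.
    df["species_name"] = df["species"].map({1: "cat", 2: "dog"})
    # Zero-based targets, as expected by nn.CrossEntropyLoss.
    df["label"] = df["class_id"] - 1          # one of 37 breeds (global approach)
    df["species_label"] = df["species"] - 1   # 0 = cat, 1 = dog
    df["breed_label"] = df["breed_id"] - 1    # index within the species
    return df


if __name__ == "__main__":
    sample = io.StringIO(
        "#Image CLASS-ID SPECIES BREED ID\n"
        "Abyssinian_100 1 1 1\n"
        "american_bulldog_10 2 2 1\n"
    )
    print(load_annotations(sample)[["image", "breed", "species_name", "label"]])
```

Keeping both a global label and a (species, within-species breed) pair makes the same table usable by a flat classifier and by a hierarchical one.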
Data Augmentation and Transformation: Training, validation, and test datasets are created with appropriate transformations and augmentation.
Model Definition: Two custom CNN models are defined in CatsDogsImageClassification and CatsDogsHierarchicalClassification. Both architectures include convolutional layers followed by fully connected layers that perform the breed classification.
Training and Validation: Models are trained using the training and validation datasets. Training loss and accuracy are monitored to check that the models converge.
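A minimal sketch of a train/validate loop that records per-epoch loss and accuracy. The optimizer (Adam), the loss (cross-entropy), and the names `run_epoch`, `fit` and the `history` keys are assumptions for illustration, not the repository's code.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def run_epoch(model: nn.Module, loader: DataLoader, criterion, optimizer=None, device="cpu"):
    """One pass over `loader`: trains if an optimizer is given, otherwise only evaluates.

    Returns (mean loss, accuracy) over all samples seen.
    """
    training = optimizer is not None
    model.train(training)  # dropout / batch-norm updates only when training
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


def fit(model, train_loader, val_loader, epochs=10, lr=1e-3, device="cpu"):
    """Train with Adam + cross-entropy, recording per-epoch loss/accuracy (hypothetical helper)."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Only parameters that require gradients (e.g. an unfrozen head) are optimized.
    optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = run_epoch(model, val_loader, criterion, device=device)
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        print(f"epoch {epoch}: train loss {train_loss:.4f} acc {train_acc:.3f} | "
              f"val loss {val_loss:.4f} acc {val_acc:.3f}")
    return history
```

This loop assumes the model returns a single logits tensor, as a flat breed classifier or VGG16 does; a model with several heads would need its own loss computation. A validation loss that rises while the training loss keeps falling is the usual sign of overfitting.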
Testing: Trained models are evaluated on the test dataset to measure their performance. Confusion matrices are generated to visualize classification results.
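A minimal sketch of test-set evaluation with a confusion matrix, computed with NumPy and drawn with seaborn. The helper names `collect_predictions`, `confusion_matrix` and `plot_confusion_matrix` are hypothetical; rows are true classes and columns are predicted classes.

```python
import numpy as np
import torch


@torch.no_grad()
def collect_predictions(model, loader, device="cpu"):
    """Run the model over `loader`; return (y_true, y_pred) as NumPy arrays."""
    model.eval()
    y_true, y_pred = [], []
    for images, labels in loader:
        logits = model(images.to(device))
        y_pred.append(logits.argmax(dim=1).cpu())
        y_true.append(labels.cpu())
    return torch.cat(y_true).numpy(), torch.cat(y_pred).numpy()


def confusion_matrix(y_true, y_pred, num_classes):
    """cm[i, j] = number of samples of true class i predicted as class j."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    # Encode each (true, pred) pair as one index, count, then fold into a matrix.
    flat = num_classes * y_true + y_pred
    return np.bincount(flat, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def plot_confusion_matrix(cm, class_names):
    """Draw the matrix as an annotated seaborn heatmap."""
    # Imported lazily so the helpers above only need NumPy and PyTorch.
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel("Predicted breed")
    ax.set_ylabel("True breed")
    fig.tight_layout()
    return fig
```

Overall accuracy is the trace of the matrix divided by its sum, and off-diagonal clusters show which breeds are confused with each other.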
Visualization: Training loss curves and model activation heatmaps are plotted for analysis.
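A minimal sketch of both plots. The heatmap here is a simple channel-averaged activation map captured with a forward hook, which is only one way to visualize activations (it is not Grad-CAM), and the repository may use a different method. The helper names and the assumed `history` dict (per-epoch lists under hypothetical `train_loss` and `val_loss` keys) are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def activation_heatmap(model: nn.Module, layer: nn.Module, image: torch.Tensor) -> torch.Tensor:
    """Channel-averaged activations of `layer` for one (3, H, W) image, as an (H, W) map in [0, 1]."""
    captured = {}

    def hook(module, inputs, output):
        # clone() so later in-place ops (e.g. ReLU(inplace=True)) cannot alter the copy
        captured["act"] = output.detach().clone()

    handle = layer.register_forward_hook(hook)
    try:
        model.eval()
        with torch.no_grad():
            model(image.unsqueeze(0))
    finally:
        handle.remove()

    heat = captured["act"].mean(dim=1, keepdim=True)  # (1, C, h, w) -> (1, 1, h, w)
    heat = F.interpolate(heat, size=image.shape[-2:], mode="bilinear", align_corners=False)[0, 0]
    heat = heat - heat.min()
    return heat / heat.max().clamp_min(1e-8)


def plot_loss_curves(history):
    """Plot per-epoch training and validation loss."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot(history["train_loss"], label="train")
    ax.plot(history["val_loss"], label="validation")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    return fig


def plot_heatmap(image: torch.Tensor, heat: torch.Tensor, alpha: float = 0.5):
    """Overlay the heatmap on the (min-max rescaled) input image."""
    import matplotlib.pyplot as plt

    img = image.permute(1, 2, 0)
    img = (img - img.min()) / (img.max() - img.min()).clamp_min(1e-8)
    fig, ax = plt.subplots()
    ax.imshow(img.numpy())
    ax.imshow(heat.numpy(), cmap="jet", alpha=alpha)
    ax.axis("off")
    return fig
```

The `layer` argument would typically be the last convolutional block, whose activations are coarse but most class-specific; earlier layers give sharper maps that mostly highlight edges and textures.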
For more information, please read the documentation provided here.