From cd7ab61531dd68e83ea78bd2fe5f0f946fda5d73 Mon Sep 17 00:00:00 2001
From: fm94
Date: Sun, 9 Jun 2024 19:00:19 +0200
Subject: [PATCH] add batch size

---
 include/data/dataset.hpp                 | 11 ++++-
 include/graph/sequential.hpp             |  2 +-
 include/graph/trainable.hpp              |  2 +-
 src/data/dataset.cpp                     | 59 ++++++++++++++++++++----
 src/graph/sequential.cpp                 |  3 +-
 tests/test_adder_inference.cpp           |  3 +-
 tests/test_adder_inference_with_adam.cpp |  3 +-
 tests/test_xor_inference.cpp             | 13 +++---
 8 files changed, 73 insertions(+), 23 deletions(-)

diff --git a/include/data/dataset.hpp b/include/data/dataset.hpp
index 972e61d..a466a2c 100644
--- a/include/data/dataset.hpp
+++ b/include/data/dataset.hpp
@@ -1,6 +1,8 @@
 #pragma once
 
 #include <Eigen/Dense>
+#include <algorithm>
+#include <random>
 #include <utility>
 
 namespace Tipousi
@@ -10,16 +12,19 @@ namespace Tipousi
         class Dataset
         {
           public:
-            Dataset(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y);
+            Dataset(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y,
+                    size_t batch_size);
             ~Dataset() = default;
 
+            void shuffle();
+
             using DataPair = std::pair<Eigen::MatrixXf, Eigen::MatrixXf>;
 
             class Iterator
             {
               public:
                 Iterator(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y,
-                         size_t index);
+                         size_t index, size_t batch_size);
                 Iterator &operator++();
                 bool operator!=(const Iterator &other) const;
                 DataPair operator*() const;
@@ -28,6 +33,7 @@ namespace Tipousi
                 const Eigen::MatrixXf &m_X;
                 const Eigen::MatrixXf &m_y;
                 size_t m_index;
+                size_t m_batch_size;
             };
 
             Iterator begin() const;
@@ -36,6 +42,7 @@ namespace Tipousi
           private:
             Eigen::MatrixXf m_X;
            Eigen::MatrixXf m_y;
+            size_t m_batch_size;
         };
     }; // namespace Data
 }; // namespace Tipousi
\ No newline at end of file
diff --git a/include/graph/sequential.hpp b/include/graph/sequential.hpp
index dd3f014..9270835 100644
--- a/include/graph/sequential.hpp
+++ b/include/graph/sequential.hpp
@@ -29,7 +29,7 @@ namespace Tipousi
         void forward(const Eigen::MatrixXf &in, Eigen::MatrixXf &out);
         void backward(Eigen::MatrixXf &initial_grads);
 
-        void train(const Data::Dataset &dataset, const Loss::LossBase &loss,
+        void train(Data::Dataset &dataset, const Loss::LossBase &loss,
                    const uint32_t n_epochs) override;
 
       private:
diff --git a/include/graph/trainable.hpp b/include/graph/trainable.hpp
index b03aa0a..245e1d6 100644
--- a/include/graph/trainable.hpp
+++ b/include/graph/trainable.hpp
@@ -11,7 +11,7 @@ namespace Tipousi
         class Trainable
         {
          public:
-            virtual void train(const Data::Dataset &dataset,
+            virtual void train(Data::Dataset &dataset,
                                const Loss::LossBase &loss,
                                const uint32_t n_epochs) = 0;
 
diff --git a/src/data/dataset.cpp b/src/data/dataset.cpp
index 086ccf1..8fcd80b 100644
--- a/src/data/dataset.cpp
+++ b/src/data/dataset.cpp
@@ -4,43 +4,82 @@ namespace Tipousi
 {
     namespace Data
    {
-        Dataset::Dataset(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y)
-            : m_X(X), m_y(Y)
+        Dataset::Dataset(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y,
+                         size_t batch_size)
+            : m_X(X), m_y(Y), m_batch_size(batch_size)
         {
+            if (batch_size > static_cast<size_t>(X.rows()))
+            {
+                throw std::invalid_argument(
+                    "Batch size cannot be larger than the number of samples "
+                    "in the dataset");
+            }
         }
 
         Dataset::Iterator::Iterator(const Eigen::MatrixXf &X,
-                                    const Eigen::MatrixXf &Y, size_t index)
-            : m_X(X), m_y(Y), m_index(index)
+                                    const Eigen::MatrixXf &Y, size_t index,
+                                    size_t batch_size)
+            : m_X(X), m_y(Y), m_index(index), m_batch_size(batch_size)
         {
         }
 
         Dataset::Iterator &Dataset::Iterator::operator++()
         {
-            ++m_index;
+            m_index += m_batch_size;
            return *this;
         }
 
         bool Dataset::Iterator::operator!=(const Dataset::Iterator &other) const
         {
-            return m_index != other.m_index;
+            // compare with '<' rather than '!=' so that iteration also
+            // stops when the step overshoots end(): drops a partial batch
+            return m_index < other.m_index;
         }
 
         Dataset::DataPair Dataset::Iterator::operator*() const
         {
-            Eigen::MatrixXf x = m_X.row(m_index).cast<float>();
-            Eigen::MatrixXf y = m_y.row(m_index).cast<float>();
+            // clamp to the row count; defensive, since end() drops the tail
+            size_t end_index = std::min(m_index + m_batch_size,
+                                        static_cast<size_t>(m_X.rows()));
+            Eigen::MatrixXf x =
+                m_X.block(m_index, 0, end_index - m_index, m_X.cols());
+            Eigen::MatrixXf y =
+                m_y.block(m_index, 0, end_index - m_index, m_y.cols());
             return {x, y};
         }
 
         Dataset::Iterator Dataset::begin() const
         {
-            return Iterator(m_X, m_y, 0);
+            return Iterator(m_X, m_y, 0, m_batch_size);
         }
 
         Dataset::Iterator Dataset::end() const
         {
-            return Iterator(m_X, m_y, m_X.rows());
+            // align the end index down to the last complete batch so
+            // that any trailing incomplete batch is dropped (operator!=
+            // uses '<', so an overshooting iterator still terminates)
+            size_t end_index =
+                (m_X.rows() / m_batch_size) * m_batch_size;
+            return Iterator(m_X, m_y, end_index, m_batch_size);
+        }
+
+        void Dataset::shuffle()
+        {
+            std::vector<int> indices(m_X.rows());
+            std::iota(indices.begin(), indices.end(), 0);
+            std::random_device rd;
+            std::mt19937 g(rd());
+            std::shuffle(indices.begin(), indices.end(), g);
+
+            Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> perm(
+                indices.size());
+            for (size_t i = 0; i < indices.size(); ++i)
+            {
+                perm.indices()[i] = indices[i];
+            }
+
+            m_X = perm * m_X;
+            m_y = perm * m_y;
         }
     } // namespace Data
 } // namespace Tipousi
\ No newline at end of file
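
A note on the iteration contract above: operator*() yields row blocks of
size batch_size, and end() is aligned down to the last complete batch, so
one pass over N samples visits floor(N / batch_size) batches. A minimal
sketch of the resulting API (not part of the patch; it assumes the patched
headers are on the include path, and the 5x2 shapes are illustrative only):

    #include "data/dataset.hpp"

    #include <iostream>

    int main()
    {
        Eigen::MatrixXf X(5, 2), Y(5, 1);
        X.setRandom();
        Y.setRandom();

        // batch_size 2 over 5 rows: two batches, the fifth row is dropped
        Tipousi::Data::Dataset dataset(X, Y, 2);
        for (const auto &batch : dataset)
        {
            const Eigen::MatrixXf &x = batch.first; // 2x2 feature block
            std::cout << x.rows() << "x" << x.cols() << std::endl;
        }
    }
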
diff --git a/src/graph/sequential.cpp b/src/graph/sequential.cpp
index 1fe09fb..cfdc0d8 100644
--- a/src/graph/sequential.cpp
+++ b/src/graph/sequential.cpp
@@ -78,7 +78,7 @@ namespace Tipousi
             }
         }
 
-        void Sequential::train(const Data::Dataset &dataset,
+        void Sequential::train(Data::Dataset &dataset,
                                const Loss::LossBase &loss_func,
                                const uint32_t n_epochs)
         {
@@ -100,6 +100,7 @@ namespace Tipousi
                    backward(out_grad);
                    counter++;
                }
+                dataset.shuffle();
                std::cout << "Epoch: " << i << ", Loss: " << total_loss / counter
                          << std::endl;
            }
diff --git a/tests/test_adder_inference.cpp b/tests/test_adder_inference.cpp
index ce24342..f8494ae 100644
--- a/tests/test_adder_inference.cpp
+++ b/tests/test_adder_inference.cpp
@@ -48,7 +48,8 @@ TEST(SimpleNetTest, AdderTest)
     Y << 2, 3, 3, 4;
 
     // create dataset
-    Dataset dataset(X, Y);
+    size_t batch_size = 1;
+    Dataset dataset(X, Y, batch_size);
 
     // define the loss
     MSE mse;
diff --git a/tests/test_adder_inference_with_adam.cpp b/tests/test_adder_inference_with_adam.cpp
index ca073d0..d05362f 100644
--- a/tests/test_adder_inference_with_adam.cpp
+++ b/tests/test_adder_inference_with_adam.cpp
@@ -48,7 +48,8 @@ TEST(SimpleNetTest, AdderTestWithAdam)
     Y << 2, 3, 3, 4;
 
     // create dataset
-    Dataset dataset(X, Y);
+    size_t batch_size = 1;
+    Dataset dataset(X, Y, batch_size);
 
     // define the loss
     MSE mse;
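
Because Sequential::train now reshuffles the dataset after every epoch, the
drop-last rule is mostly harmless: a sample that falls into the discarded
tail batch of one epoch will usually land in a complete batch of the next.
The XOR test below exercises exactly this edge case, since batch_size 3
against 4 samples leaves one sample out per epoch. The per-epoch sample
count follows directly from the end() computation (a hedged restatement for
illustration, not part of the patch):

    // rows visited per epoch under the drop-last rule of Dataset::end()
    size_t rows_seen(size_t n_samples, size_t batch_size)
    {
        return (n_samples / batch_size) * batch_size; // (4 / 3) * 3 == 3
    }
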
diff --git a/tests/test_xor_inference.cpp b/tests/test_xor_inference.cpp
index 4eb3bf0..926ff31 100644
--- a/tests/test_xor_inference.cpp
+++ b/tests/test_xor_inference.cpp
@@ -5,7 +5,7 @@
 #include "graph/sequential.hpp"
 #include "layer/dense.hpp"
 #include "loss/mse.hpp"
-#include "optimizer/sgd.hpp"
+#include "optimizer/adam.hpp"
 #include <chrono>
 #include <gtest/gtest.h>
 #include <iostream>
@@ -36,11 +36,11 @@ TEST(SimpleNetTest, XORTest)
     node3->add_input(node2);
     node4->add_input(node3);
 
-    float learning_rate{0.5f};
-    SGD sgd(learning_rate);
+    float learning_rate{0.08f};
+    Adam optimizer(learning_rate);
 
     // create the graph (pass input and output nodes)
-    Sequential net(node1, node4, sgd);
+    Sequential net(node1, node4, optimizer);
 
     // test inference
     Eigen::MatrixXf X(4, 2);
@@ -51,12 +51,13 @@ TEST(SimpleNetTest, XORTest)
     Y << 0, 1, 1, 0;
 
     // create dataset
-    Dataset dataset(X, Y);
+    size_t batch_size = 3;
+    Dataset dataset(X, Y, batch_size);
 
     // define the loss
     MSE mse;
 
     auto start = std::chrono::high_resolution_clock::now();
-    net.train(dataset, mse, 200);
+    net.train(dataset, mse, 100);
     auto end = std::chrono::high_resolution_clock::now();
     auto duration =
         std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
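
Dataset::shuffle() reorders rows by left-multiplying with an Eigen
permutation matrix; since the same permutation is applied to m_X and m_y,
sample/label pairs stay aligned. A self-contained sketch of that idiom
(assumes only Eigen and the standard library; not part of the patch):

    #include <Eigen/Dense>

    #include <algorithm>
    #include <iostream>
    #include <numeric>
    #include <random>
    #include <vector>

    int main()
    {
        Eigen::MatrixXf X(4, 2), Y(4, 1);
        X << 1, 1, 2, 2, 3, 3, 4, 4;
        Y << 1, 2, 3, 4;

        // random permutation of the row indices
        std::vector<int> indices(X.rows());
        std::iota(indices.begin(), indices.end(), 0);
        std::mt19937 gen(std::random_device{}());
        std::shuffle(indices.begin(), indices.end(), gen);

        Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> perm(
            X.rows());
        for (size_t i = 0; i < indices.size(); ++i)
        {
            perm.indices()[i] = indices[i];
        }

        // the same row permutation keeps X and Y rows paired
        std::cout << perm * X << "\n---\n" << perm * Y << std::endl;
    }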