Skip to content

Commit

Permalink
add batchsize
Browse files Browse the repository at this point in the history
  • Loading branch information
fm94 committed Jun 9, 2024
1 parent bdd2602 commit cd7ab61
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 23 deletions.
11 changes: 9 additions & 2 deletions include/data/dataset.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once

#include <Eigen/Dense>

#include <algorithm>
#include <numeric>
#include <random>
#include <stdexcept>
#include <vector>

namespace Tipousi
Expand All @@ -10,16 +12,19 @@ namespace Tipousi
class Dataset
{
public:
Dataset(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y);
Dataset(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y,
size_t batch_size);
~Dataset() = default;

void shuffle();

using DataPair = std::pair<Eigen::MatrixXf, Eigen::MatrixXf>;

class Iterator
{
public:
Iterator(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y,
size_t index);
size_t index, size_t batch_size);
Iterator &operator++();
bool operator!=(const Iterator &other) const;
DataPair operator*() const;
Expand All @@ -28,6 +33,7 @@ namespace Tipousi
const Eigen::MatrixXf &m_X;
const Eigen::MatrixXf &m_y;
size_t m_index;
size_t m_batch_size;
};

Iterator begin() const;
Expand All @@ -36,6 +42,7 @@ namespace Tipousi
private:
Eigen::MatrixXf m_X;
Eigen::MatrixXf m_y;
size_t m_batch_size;
};
}; // namespace Data
}; // namespace Tipousi
2 changes: 1 addition & 1 deletion include/graph/sequential.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace Tipousi
void forward(const Eigen::MatrixXf &in, Eigen::MatrixXf &out);
void backward(Eigen::MatrixXf &initial_grads);

void train(const Data::Dataset &dataset, const Loss::LossBase &loss,
void train(Data::Dataset &dataset, const Loss::LossBase &loss,
const uint32_t n_epochs) override;

private:
Expand Down
2 changes: 1 addition & 1 deletion include/graph/trainable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace Tipousi
class Trainable
{
public:
virtual void train(const Data::Dataset &dataset,
virtual void train(Data::Dataset &dataset,
const Loss::LossBase &loss,
const uint32_t n_epochs) = 0;

Expand Down
59 changes: 49 additions & 10 deletions src/data/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,82 @@ namespace Tipousi
{
namespace Data
{
/// Build a dataset over (X, Y) that is iterated in mini-batches of
/// `batch_size` rows. X and Y are copied so the dataset can shuffle
/// them independently of the caller's matrices.
///
/// @throws std::invalid_argument if batch_size is 0 or larger than the
///         number of rows in X.
Dataset::Dataset(const Eigen::MatrixXf &X, const Eigen::MatrixXf &Y,
                 size_t batch_size)
    : m_X(X), m_y(Y), m_batch_size(batch_size)
{
    // A zero batch size would make Iterator::operator++ a no-op
    // (infinite loop) and end() divide by zero — reject it up front.
    if (batch_size == 0)
    {
        throw std::invalid_argument("Batch size must be at least 1");
    }
    // Eigen rows() is a signed Eigen::Index; cast to avoid a
    // signed/unsigned comparison.
    if (batch_size > static_cast<size_t>(X.rows()))
    {
        throw std::invalid_argument(
            "Batch size cannot be larger than the number of samples in "
            "the dataset");
    }
}

/// Construct a batch iterator over (X, Y). `index` is the row at which
/// iteration starts; operator++ advances it by `batch_size` rows.
/// The iterator only references the matrices — it must not outlive
/// the Dataset that owns them.
Dataset::Iterator::Iterator(const Eigen::MatrixXf &X,
                            const Eigen::MatrixXf &Y, size_t index,
                            size_t batch_size)
    : m_X(X), m_y(Y), m_index(index), m_batch_size(batch_size)
{
}

/// Advance to the first row of the next batch.
Dataset::Iterator &Dataset::Iterator::operator++()
{
    m_index += m_batch_size;
    return *this;
}

bool Dataset::Iterator::operator!=(const Dataset::Iterator &other) const
{
    // Strict `<` instead of `!=`: end() sits at the last *complete*
    // batch boundary, so iteration stops as soon as m_index reaches or
    // passes it — this is what drops a trailing incomplete batch.
    // NOTE(review): asymmetric comparison; assumes `other` is always
    // the end() iterator (true for range-for usage).
    return m_index < other.m_index;
}

/// Materialize the current batch as a (features, labels) pair of
/// row-blocks [m_index, endIndex).
Dataset::DataPair Dataset::Iterator::operator*() const
{
    // Clamp to the matrix size. With end() dropping the last
    // incomplete batch this clamp never triggers in normal range-for
    // iteration, but it keeps dereferencing safe near the tail.
    size_t endIndex = std::min(m_index + m_batch_size,
                               static_cast<size_t>(m_X.rows()));
    Eigen::MatrixXf x =
        m_X.block(m_index, 0, endIndex - m_index, m_X.cols());
    Eigen::MatrixXf y =
        m_y.block(m_index, 0, endIndex - m_index, m_y.cols());
    return {x, y};
}

/// Iterator positioned at the first batch (row 0).
Dataset::Iterator Dataset::begin() const
{
    return Iterator(m_X, m_y, 0, m_batch_size);
}

/// Iterator positioned at the last complete batch boundary, so that
/// iteration drops a trailing incomplete batch.
Dataset::Iterator Dataset::end() const
{
    // Guard the division: the constructor rejects oversized batches,
    // but a zero batch size would otherwise be a division by zero.
    size_t rows = static_cast<size_t>(m_X.rows());
    size_t endIndex =
        m_batch_size == 0 ? 0 : (rows / m_batch_size) * m_batch_size;
    return Iterator(m_X, m_y, endIndex, m_batch_size);
}

// Shuffle the rows of X and y with the SAME random permutation so
// that sample/label pairs stay aligned. Called once per epoch by
// Sequential::train.
void Dataset::shuffle()
{
    // 0..rows-1, randomly permuted (std::iota requires <numeric>).
    std::vector<int> indices(m_X.rows());
    std::iota(indices.begin(), indices.end(), 0);
    std::random_device rd;
    std::mt19937 g(rd());
    std::shuffle(indices.begin(), indices.end(), g);

    // Turn the shuffled index list into an Eigen row-permutation
    // matrix and apply it to both matrices.
    Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> perm(
        indices.size());
    for (size_t i = 0; i < indices.size(); ++i)
    {
        perm.indices()[i] = indices[i];
    }

    m_X = perm * m_X;
    m_y = perm * m_y;
}
} // namespace Data
} // namespace Tipousi
3 changes: 2 additions & 1 deletion src/graph/sequential.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ namespace Tipousi
}
}

void Sequential::train(const Data::Dataset &dataset,
void Sequential::train(Data::Dataset &dataset,
const Loss::LossBase &loss_func,
const uint32_t n_epochs)
{
Expand All @@ -100,6 +100,7 @@ namespace Tipousi
backward(out_grad);
counter++;
}
dataset.shuffle();
std::cout << "Epoch: " << i
<< ", Loss: " << total_loss / counter << std::endl;
}
Expand Down
3 changes: 2 additions & 1 deletion tests/test_adder_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ TEST(SimpleNetTest, AdderTest)
Y << 2, 3, 3, 4;

// create dataset
Dataset dataset(X, Y);
size_t batch_size = 1;
Dataset dataset(X, Y, batch_size);

// define the loss
MSE mse;
Expand Down
3 changes: 2 additions & 1 deletion tests/test_adder_inference_with_adam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ TEST(SimpleNetTest, AdderTestWithAdam)
Y << 2, 3, 3, 4;

// create dataset
Dataset dataset(X, Y);
size_t batch_size = 1;
Dataset dataset(X, Y, batch_size);

// define the loss
MSE mse;
Expand Down
13 changes: 7 additions & 6 deletions tests/test_xor_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "graph/sequential.hpp"
#include "layer/dense.hpp"
#include "loss/mse.hpp"
#include "optimizer/sgd.hpp"
#include "optimizer/adam.hpp"
#include <chrono>
#include <gtest/gtest.h>
#include <iostream>
Expand Down Expand Up @@ -36,11 +36,11 @@ TEST(SimpleNetTest, XORTest)
node3->add_input(node2);
node4->add_input(node3);

float learning_rate{0.5f};
SGD sgd(learning_rate);
float learning_rate{0.08f};
Adam optimizer(learning_rate);

// create the graph (pass input and output nodes)
Sequential net(node1, node4, sgd);
Sequential net(node1, node4, optimizer);

// test inference
Eigen::MatrixXf X(4, 2);
Expand All @@ -51,12 +51,13 @@ TEST(SimpleNetTest, XORTest)
Y << 0, 1, 1, 0;

// create dataset
Dataset dataset(X, Y);
size_t batch_size = 3;
Dataset dataset(X, Y, batch_size);

// define the loss
MSE mse;
auto start = std::chrono::high_resolution_clock::now();
net.train(dataset, mse, 200);
net.train(dataset, mse, 100);
auto end = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
Expand Down

0 comments on commit cd7ab61

Please sign in to comment.