-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implementing the xor example: losses, optimizers, datasets training etc.
- Loading branch information
Showing
15 changed files
with
242 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
.vscode/ | ||
.vs/ | ||
build/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"configurations": [ | ||
{ | ||
"name": "x64-Debug", | ||
"generator": "Ninja", | ||
"configurationType": "Debug", | ||
"inheritEnvironments": [ "msvc_x64_x64" ], | ||
"buildRoot": "${projectDir}\\out\\build\\${name}", | ||
"installRoot": "${projectDir}\\out\\install\\${name}", | ||
"cmakeCommandArgs": "", | ||
"buildCommandArgs": "", | ||
"ctestCommandArgs": "" | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#pragma once | ||
#include <Eigen/Dense> | ||
#include <vector> | ||
|
||
namespace Tipousi | ||
{ | ||
namespace Data | ||
{ | ||
class Dataset | ||
{ | ||
public: | ||
Dataset(const Eigen::MatrixXd &X, const Eigen::MatrixXd &Y); | ||
~Dataset() = default; | ||
|
||
using DataPair = std::pair<Eigen::MatrixXf, Eigen::MatrixXf>; | ||
|
||
class Iterator | ||
{ | ||
public: | ||
Iterator(const Eigen::MatrixXd &X, const Eigen::MatrixXd &Y, | ||
size_t index); | ||
Iterator &operator++(); | ||
bool operator!=(const Iterator &other) const; | ||
DataPair operator*() const; | ||
|
||
private: | ||
const Eigen::MatrixXd &m_X; | ||
const Eigen::MatrixXd &m_y; | ||
size_t m_index; | ||
}; | ||
|
||
Iterator begin() const; | ||
Iterator end() const; | ||
|
||
private: | ||
Eigen::MatrixXd m_X; | ||
Eigen::MatrixXd m_y; | ||
}; | ||
}; // namespace Data | ||
}; // namespace Tipousi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#pragma once | ||
|
||
#include "data/dataset.hpp" | ||
#include "loss/base.hpp" | ||
#include "optimizer/base.hpp" | ||
|
||
namespace Tipousi | ||
{ | ||
namespace Graph | ||
{ | ||
class Trainable | ||
{ | ||
public: | ||
virtual void train(const Data::Dataset &dataset, | ||
const Optimizer::OptimizerBase &optimizer, | ||
const Loss::LossBase &loss, | ||
const uint32_t n_epochs) = 0; | ||
|
||
protected: | ||
Trainable() = default; | ||
virtual ~Trainable() = default; | ||
}; | ||
|
||
} // namespace Graph | ||
} // namespace Tipousi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#pragma once | ||
#include <Eigen/Dense> | ||
|
||
namespace Tipousi | ||
{ | ||
namespace Loss | ||
{ | ||
class LossBase | ||
{ | ||
public: | ||
LossBase() = default; | ||
virtual ~LossBase() = default; | ||
|
||
virtual float compute(const Eigen::MatrixXf &y, | ||
const Eigen::MatrixXf &y_pred) const = 0; | ||
|
||
virtual void grad(Eigen::MatrixXf &out_grad, | ||
const Eigen::MatrixXf &y, | ||
const Eigen::MatrixXf &y_pred) const = 0; | ||
}; | ||
}; // namespace Loss | ||
}; // namespace Tipousi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,20 @@ | ||
#pragma once | ||
#include "loss/base.hpp" | ||
|
||
namespace Tipousi | ||
{ | ||
namespace Loss | ||
{ | ||
class MSE | ||
class MSE : public LossBase | ||
{ | ||
MSE() = default; | ||
~MSE() = default; | ||
|
||
float compute(const Eigen::MatrixXf &y, | ||
const Eigen::MatrixXf &y_pred) const override; | ||
|
||
void grad(Eigen::MatrixXf &out_grad, const Eigen::MatrixXf &y, | ||
const Eigen::MatrixXf &y_pred) const override; | ||
}; | ||
}; // namespace Loss | ||
}; // namespace Tipousi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#pragma once | ||
#include "optimizer/base.hpp" | ||
|
||
namespace Tipousi | ||
{ | ||
namespace Optimizer | ||
{ | ||
class SGD : public OptimizerBase | ||
{ | ||
public: | ||
SGD(); | ||
~SGD() = default; | ||
}; | ||
}; // namespace Optimizer | ||
}; // namespace Tipousi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#include "data/dataset.hpp" | ||
|
||
namespace Tipousi | ||
{ | ||
namespace Data | ||
{ | ||
Dataset::Dataset(const Eigen::MatrixXd &X, const Eigen::MatrixXd &Y) | ||
: m_X(X), m_y(Y) | ||
{ | ||
} | ||
|
||
Dataset::Iterator::Iterator(const Eigen::MatrixXd &X, | ||
const Eigen::MatrixXd &Y, size_t index) | ||
: m_X(X), m_y(Y), m_index(index) | ||
{ | ||
} | ||
|
||
Dataset::Iterator &Dataset::Iterator::operator++() | ||
{ | ||
++m_index; | ||
return *this; | ||
} | ||
|
||
bool Dataset::Iterator::operator!=(const Dataset::Iterator &other) const | ||
{ | ||
return m_index != other.m_index; | ||
} | ||
|
||
Dataset::DataPair Dataset::Iterator::operator*() const | ||
{ | ||
Eigen::MatrixXf x = m_X.row(m_index).cast<float>(); | ||
Eigen::MatrixXf y = m_y.row(m_index).cast<float>(); | ||
return {x, y}; | ||
} | ||
|
||
Dataset::Iterator Dataset::begin() const | ||
{ | ||
return Iterator(m_X, m_y, 0); | ||
} | ||
|
||
Dataset::Iterator Dataset::end() const | ||
{ | ||
return Iterator(m_X, m_y, m_X.rows()); | ||
} | ||
} // namespace Data | ||
} // namespace Tipousi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#include "loss/mse.hpp" | ||
|
||
namespace Tipousi | ||
{ | ||
namespace Loss | ||
{ | ||
float MSE::compute(const Eigen::MatrixXf &y, | ||
const Eigen::MatrixXf &y_pred) const | ||
{ | ||
return (y - y_pred).array().square().mean(); | ||
} | ||
|
||
void MSE::grad(Eigen::MatrixXf &out_grad, const Eigen::MatrixXf &y, | ||
const Eigen::MatrixXf &y_pred) const | ||
{ | ||
out_grad = 2.0f * (y_pred - y) / y.rows(); | ||
} | ||
} // namespace Loss | ||
} // namespace Tipousi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#include "optimizer/sgd.hpp" | ||
|
||
namespace Tipousi | ||
{ | ||
namespace Optimizer | ||
{ | ||
SGD::SGD() {} | ||
} // namespace Optimizer | ||
} // namespace Tipousi |