-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
163 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#include <torch/torch.h> | ||
|
||
// Define a new Module. | ||
struct Net : torch::nn::Module { | ||
Net() { | ||
// Construct and register two Linear submodules. | ||
fc1 = register_module("fc1", torch::nn::Linear(784, 64)); | ||
fc2 = register_module("fc2", torch::nn::Linear(64, 32)); | ||
fc3 = register_module("fc3", torch::nn::Linear(32, 10)); | ||
} | ||
|
||
// Implement the Net's algorithm. | ||
torch::Tensor forward(torch::Tensor x) { | ||
// Use one of many tensor manipulation functions. | ||
x = torch::relu(fc1->forward(x.reshape({x.size(0), 784}))); | ||
x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training()); | ||
x = torch::relu(fc2->forward(x)); | ||
x = torch::log_softmax(fc3->forward(x), /*dim=*/1); | ||
return x; | ||
} | ||
|
||
// Use one of many "standard library" modules. | ||
torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}; | ||
}; | ||
|
||
int main() { | ||
// Create a new Net. | ||
auto net = std::make_shared<Net>(); | ||
|
||
// Create a multi-threaded data loader for the MNIST dataset. | ||
auto data_loader = torch::data::make_data_loader( | ||
torch::data::datasets::MNIST("./data").map( | ||
torch::data::transforms::Stack<>()), | ||
/*batch_size=*/64); | ||
|
||
// Instantiate an SGD optimization algorithm to update our Net's parameters. | ||
torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01); | ||
|
||
for (size_t epoch = 1; epoch <= 10; ++epoch) { | ||
size_t batch_index = 0; | ||
// Iterate the data loader to yield batches from the dataset. | ||
for (auto& batch : *data_loader) { | ||
// Reset gradients. | ||
optimizer.zero_grad(); | ||
// Execute the model on the input data. | ||
torch::Tensor prediction = net->forward(batch.data); | ||
// Compute a loss value to judge the prediction of our model. | ||
torch::Tensor loss = torch::nll_loss(prediction, batch.target); | ||
// Compute gradients of the loss w.r.t. the parameters of our model. | ||
loss.backward(); | ||
// Update the parameters based on the calculated gradients. | ||
optimizer.step(); | ||
// Output the loss and checkpoint every 100 batches. | ||
if (++batch_index % 100 == 0) { | ||
std::cout << "Epoch: " << epoch << " | Batch: " << batch_index | ||
<< " | Loss: " << loss.item<float>() << std::endl; | ||
// Serialize your model periodically as a checkpoint. | ||
torch::save(net, "net.pt"); | ||
} | ||
} | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import torch | ||
import math | ||
import sys | ||
|
||
|
||
dtype = torch.float | ||
device = torch.device(sys.argv[1]) | ||
|
||
# Create random input and output data | ||
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) | ||
y = torch.sin(x) | ||
|
||
# Randomly initialize weights | ||
a = torch.randn((), device=device, dtype=dtype) | ||
b = torch.randn((), device=device, dtype=dtype) | ||
c = torch.randn((), device=device, dtype=dtype) | ||
d = torch.randn((), device=device, dtype=dtype) | ||
|
||
learning_rate = 1e-6 | ||
for t in range(2000): | ||
# Forward pass: compute predicted y | ||
y_pred = a + b * x + c * x ** 2 + d * x ** 3 | ||
|
||
# Compute and print loss | ||
loss = (y_pred - y).pow(2).sum().item() | ||
if t % 100 == 99: | ||
print(t, loss) | ||
|
||
# Backprop to compute gradients of a, b, c, d with respect to loss | ||
grad_y_pred = 2.0 * (y_pred - y) | ||
grad_a = grad_y_pred.sum() | ||
grad_b = (grad_y_pred * x).sum() | ||
grad_c = (grad_y_pred * x ** 2).sum() | ||
grad_d = (grad_y_pred * x ** 3).sum() | ||
|
||
# Update weights using gradient descent | ||
a -= learning_rate * grad_a | ||
b -= learning_rate * grad_b | ||
c -= learning_rate * grad_c | ||
d -= learning_rate * grad_d | ||
|
||
|
||
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import torch | ||
import math | ||
|
||
|
||
dtype = torch.float | ||
device = torch.device("cuda:0") | ||
|
||
# Create random input and output data | ||
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) | ||
y = torch.sin(x) | ||
|
||
# Randomly initialize weights | ||
a = torch.randn((), device=device, dtype=dtype) | ||
b = torch.randn((), device=device, dtype=dtype) | ||
c = torch.randn((), device=device, dtype=dtype) | ||
d = torch.randn((), device=device, dtype=dtype) | ||
|
||
learning_rate = 1e-6 | ||
for t in range(2000): | ||
# Forward pass: compute predicted y | ||
y_pred = a + b * x + c * x ** 2 + d * x ** 3 | ||
|
||
# Compute and print loss | ||
loss = (y_pred - y).pow(2).sum().item() | ||
if t % 100 == 99: | ||
print(t, loss) | ||
|
||
# Backprop to compute gradients of a, b, c, d with respect to loss | ||
grad_y_pred = 2.0 * (y_pred - y) | ||
grad_a = grad_y_pred.sum() | ||
grad_b = (grad_y_pred * x).sum() | ||
grad_c = (grad_y_pred * x ** 2).sum() | ||
grad_d = (grad_y_pred * x ** 3).sum() | ||
|
||
# Update weights using gradient descent | ||
a -= learning_rate * grad_a | ||
b -= learning_rate * grad_b | ||
c -= learning_rate * grad_c | ||
d -= learning_rate * grad_d | ||
|
||
|
||
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') | ||
|