Skip to content

Commit

Permalink
coding style
Browse files Browse the repository at this point in the history
  • Loading branch information
uecker committed Sep 23, 2022
1 parent 3f6ebb1 commit b6967ac
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/mnist.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ static nn_t network_mnist_create(const long odims[2], const long idims[3], enum

const struct initializer_s* init = NULL; // fallback to default initializer

long pool_size[] = {2, 2, 1};
long pool_size[] = { 2, 2, 1 };

bool conv = false; // we usecross correlation not convolution, i.e. as usual in deep-learning the convolution kernels are not flipped

Expand Down Expand Up @@ -119,7 +119,6 @@ static const char help_str[] = "Trains or applies a MNIST network.\nThis network

int main_mnist(int argc, char* argv[argc])
{

bool apply = false;
bool train = false;

Expand Down Expand Up @@ -152,11 +151,12 @@ int main_mnist(int argc, char* argv[argc])

num_init_gpu();
cuda_use_global_memory();
}

else
} else
#endif
{
num_init();
}


if (apply && train)
Expand All @@ -176,8 +176,8 @@ int main_mnist(int argc, char* argv[argc])
long dims_out[NO];
complex float* out = load_cfl(filename_out, NO, dims_out);

long bdims_in[] = { dims_in[0], dims_in[1], Nb};
long bdims_out[] = { dims_out[0], Nb};
long bdims_in[] = { dims_in[0], dims_in[1], Nb };
long bdims_out[] = { dims_out[0], Nb };

long Nt = dims_out[1]; //dataset size
assert(Nt == dims_in[2]);
Expand All @@ -199,8 +199,7 @@ int main_mnist(int argc, char* argv[argc])
(const long*[2]){ bdims_out, bdims_in},
(const long*[2]){ dims_out, dims_in },
(const complex float*[2]){ out, in },
0, BATCH_GEN_SHUFFLE_DATA, 123
);
0, BATCH_GEN_SHUFFLE_DATA, 123);

//setup for iter algorithm
int II = nn_get_nr_in_args(train_op);
Expand Down Expand Up @@ -232,6 +231,7 @@ int main_mnist(int argc, char* argv[argc])
unmap_cfl(NO, dims_out, out);
}


if (apply) {

long dims_out[] = { 10, dims_in[2] };
Expand All @@ -247,11 +247,12 @@ int main_mnist(int argc, char* argv[argc])
net = nn_get_wo_weights_F(net, weights, false); //set inputs corresponding to weights to the loaded weights

nlop_generic_apply_sameplace(nn_get_nlop(net),
1, (int[1]){ 2 }, (const long*[1]){ dims_out } , (complex float* [1]){ out },
1, (int[1]){ 3 }, (const long*[1]){ dims_in } , (const complex float*[1]){ in },
1, (int[1]){ 2 }, (const long*[1]){ dims_out }, (complex float* [1]){ out },
1, (int[1]){ 3 }, (const long*[1]){ dims_in }, (const complex float*[1]){ in },
weights->tensors[0]);

unmap_cfl(NO, dims_out, out);

nn_weights_free(weights);
}

Expand Down

0 comments on commit b6967ac

Please sign in to comment.