Skip to content

Commit

Permalink
log_du always float
Browse files Browse the repository at this point in the history
  • Loading branch information
dizcza committed Oct 9, 2020
1 parent 933ef38 commit b665d5f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
3 changes: 1 addition & 2 deletions elephant/asset/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,6 @@ def _compile_cuda_template(self, u_length):
asset_cu = cu_template.render(
ASSET_DEBUG=int(self.verbose),
precision=self.precision,
precision_printf='"%f"' if self.precision == "float" else '"%lf"',
N_THREADS=self.cuda_threads,
L=u_length, N=self.n, D=self.d)
return asset_cu
Expand All @@ -621,7 +620,7 @@ def cuda(self, log_du):
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
log_du_path = os.path.join(asset_tmp_folder, "log_du.txt")
P_total_path = os.path.join(asset_tmp_folder, "P_total.txt")
np.savetxt(log_du_path, log_du, fmt="%f")
np.savetxt(log_du_path, log_du, fmt="%.10f")
stdout, stderr = subprocess.Popen(
[asset_bin_path, log_du_path, P_total_path],
stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
Expand Down
24 changes: 12 additions & 12 deletions elephant/asset/asset.template.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ __device__ void next_sequence_sorted(int *sequence_sorted, ULL iteration) {
* @param P_out P_total output array of size L
* @param log_du_device input log_du flattened matrix of size L*(D+1)
*/
__global__ void jsf_uniform_orderstat_3d_kernel(asset_float *P_out, asset_float *log_du_device) {
__global__ void jsf_uniform_orderstat_3d_kernel(asset_float *P_out, float *log_du_device) {
unsigned int i, row;

// the row shift of log_du and P_total in the number of elements, between 0 and L
Expand All @@ -72,9 +72,9 @@ __global__ void jsf_uniform_orderstat_3d_kernel(asset_float *P_out, asset_float
// account for the last block width that can be less than L_BLOCK
const unsigned int block_width = (L - l_shift < L_BLOCK) ? (L - l_shift) : L_BLOCK;

extern __shared__ asset_float shared_mem[];
extern __shared__ float shared_mem[];
asset_float *P_total = (asset_float*) shared_mem; // L_BLOCK floats
asset_float *log_du = (asset_float*)&P_total[L_BLOCK]; // L_BLOCK * (D + 1) floats
float *log_du = (float*)&P_total[L_BLOCK]; // L_BLOCK * (D + 1) floats

for (row = threadIdx.x; row < block_width; row += blockDim.x) {
P_total[row] = 0;
Expand Down Expand Up @@ -226,7 +226,7 @@ void print_constants() {
* @param P_total_host a pointer to P_total array to be calculated
* @param log_du_host input flattened L*(D+1) matrix of log_du values
*/
void jsf_uniform_orderstat_3d(asset_float *P_total_host, const asset_float *log_du_host) {
void jsf_uniform_orderstat_3d(asset_float *P_total_host, const float *log_du_host) {
ULL it_todo = create_iteration_table();

asset_float logK_host = 0.f;
Expand Down Expand Up @@ -281,9 +281,9 @@ void jsf_uniform_orderstat_3d(asset_float *P_total_host, const asset_float *log_

printf(">>> it_todo=%llu, grid_size=%llu, N_THREADS=%u\n\n", it_todo, grid_size, n_threads);

asset_float *log_du_device;
cudaMalloc((void**)&log_du_device, sizeof(asset_float) * L * (D + 1));
cudaMemcpy(log_du_device, log_du_host, sizeof(asset_float) * L * (D + 1), cudaMemcpyHostToDevice);
float *log_du_device;
cudaMalloc((void**)&log_du_device, sizeof(float) * L * (D + 1));
cudaMemcpy(log_du_device, log_du_host, sizeof(float) * L * (D + 1), cudaMemcpyHostToDevice);

#if ASSET_DEBUG
print_constants();
Expand All @@ -294,7 +294,7 @@ void jsf_uniform_orderstat_3d(asset_float *P_total_host, const asset_float *log_
cudaDeviceSynchronize();

// Executing kernel
const unsigned long shared_mem_used = sizeof(asset_float) * l_block * (D + 2);
const unsigned long shared_mem_used = sizeof(asset_float) * l_block + sizeof(float) * l_block * (D + 1);
jsf_uniform_orderstat_3d_kernel<<<grid_size, n_threads, shared_mem_used>>>(P_total_device, log_du_device);

// Transfer data back to host memory
Expand Down Expand Up @@ -322,27 +322,27 @@ int main(int argc, char* argv[]) {
return 1;
}

asset_float log_du_host[L * (D + 1)];
float log_du_host[L * (D + 1)];
uint32_t row, col, pos;
for (row = 0; row < L; row++) {
for (col = 0; col <= D; col++) {
pos = row * (D + 1) + col;
int read_floats = fscanf(log_du_file, {{precision_printf}}, log_du_host + pos);
int read_floats = fscanf(log_du_file, "%f", log_du_host + pos);
assert(read_floats == 1);
}
}
fclose(log_du_file);

asset_float P_total[L];
jsf_uniform_orderstat_3d(P_total, (const asset_float*) log_du_host);
jsf_uniform_orderstat_3d(P_total, (const float*) log_du_host);

FILE *P_total_file = fopen(P_total_path, "w");
if (P_total_file == NULL) {
fprintf(stderr, "Could not open '%s' for writing.\n", P_total_path);
return 1;
}
for (col = 0; col < L; col++) {
fprintf(P_total_file, {{precision_printf}}, P_total[col]);
fprintf(P_total_file, "%f\n", P_total[col]);
}
fclose(P_total_file);

Expand Down

0 comments on commit b665d5f

Please sign in to comment.