Commit 5e55200 (parent 74c06d8)

Multi-tensor axpby kernel for more flexible unscaling (groundwork for pytorch#163 and pytorch#179 fix)

Showing 6 changed files with 270 additions and 3 deletions.
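In practice the new kernel is driven from Python through the multi-tensor-apply machinery, as the test added below exercises. The following is a minimal usage sketch (illustrative only, not part of this commit); it assumes the amp_C extension and apex.multi_tensor_apply.MultiTensorApply imported by the new test are available:

# Minimal usage sketch (illustrative, not part of this commit):
# computes out = a*x + b*y element-wise across lists of CUDA tensors.
import torch
from amp_C import multi_tensor_axpby
from apex.multi_tensor_apply import MultiTensorApply

applier = MultiTensorApply(2048*32)         # chunk size, as used by the test below
overflow_buf = torch.cuda.IntTensor([0])    # set to 1 by the kernel if it sees inf/nan inputs

x_list = [torch.cuda.FloatTensor(n).fill_(4.0) for n in (333, 5555)]
y_list = [torch.cuda.HalfTensor(n).fill_(16.0) for n in (333, 5555)]
out_list = [torch.empty_like(y, dtype=torch.float32) for y in y_list]

applier(multi_tensor_axpby, overflow_buf, [x_list, y_list, out_list], 2.0, 8.0)
# each element of every out tensor is now 2.0*4.0 + 8.0*16.0 = 136.0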
@@ -0,0 +1,110 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename x_t, typename y_t, typename out_t>
struct AxpbyFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<3>& tl,
    float a,
    float b)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    y_t* y = (y_t*)tl.addresses[1][tensor_loc];
    y += chunk_idx*chunk_size;

    out_t* out = (out_t*)tl.addresses[2][tensor_loc];
    out += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // Non-divergent exit condition for __syncthreads, not necessary here
    float xs[ILP];
    float ys[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        xs[ii] = 0;
        ys[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          xs[ii] = static_cast<float>(x[i]);
          ys[ii] = static_cast<float>(y[i]);
        }
      }

      // see note in multi_tensor_scale_kernel.cu
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
          if(isfinite(xs[ii]) && isfinite(ys[ii]))
            out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
          else
          {
            out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
            *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
          }
      }
    }
  }
};

void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b)
{
  using namespace at;
  // The output (downscaled) type is always float.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
        multi_tensor_apply<3>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
          a,
          b); )))

  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
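For readers skimming the diff: the functor above does the per-element work for one chunk. It loads x and y, computes a*x + b*y in fp32, casts to the output type, and writes the result even when the inputs are non-finite, additionally raising the noop/overflow flag in that case. A rough PyTorch-level sketch of that contract (illustrative only; axpby_reference is not a function in this commit):

import torch

def axpby_reference(a, b, x, y, out, overflow_buf):
    # Mirror of the kernel's contract: math in fp32, result cast to out.dtype.
    xf, yf = x.float(), y.float()
    out.copy_((a*xf + b*yf).to(out.dtype))   # inf/nan results are written through (propagated)
    if not (torch.isfinite(xf).all() and torch.isfinite(yf).all()):
        overflow_buf.fill_(1)                # analogue of the racy "*noop_gmem = 1" write
    return out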
@@ -0,0 +1,128 @@
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    from amp_C import multi_tensor_axpby
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
    disabled = True


class TestMultiTensorAxpby(unittest.TestCase):

    def setUp(self):
        common_init(self)

        self.a = 2.0
        self.b = 8.0
        self.xval = 4.0
        self.yval = 16.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        self.ref = torch.cuda.FloatTensor([136.0])

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def axpby(self, sizea, sizeb, applier, repeat_tensors,
              x_type, y_type, out_type, inplace=False):
        self.overflow_buf.zero_()
        t1 = torch.cuda.FloatTensor(sizea).fill_(1.0)
        t2 = torch.cuda.FloatTensor(sizeb).fill_(1.0)

        y_list = []
        for i in range(repeat_tensors):
            y_list += [t1.clone().to(y_type)*self.yval, t2.clone().to(y_type)*self.yval]

        x_list = [x.clone().to(x_type)*(self.xval/self.yval) for x in y_list]

        if inplace:
            out_list = y_list
        else:
            out_list = [out.clone().to(out_type)*3.0 for out in y_list]

        applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b)

        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
                        msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                          x_type, y_type, out_type, inplace))
        self.assertTrue(self.overflow_buf.item() == 0,
                        msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                          x_type, y_type, out_type, inplace))

    # def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
    #     self.overflow_buf.zero_()
    #     a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
    #     b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

    #     out_list = []
    #     for i in range(repeat_tensors):
    #         out_list += [a.clone().to(out_type), b.clone().to(out_type)]

    #     if inplace:
    #         in_list = out_list
    #     else:
    #         in_list = [out.clone().to(in_type) for out in out_list]

    #     applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

    #     self.overflow_buf.zero_()
    #     in_list[t][ind] = val
    #     applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
    #     self.assertTrue(self.overflow_buf.item())

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for x_type in (torch.float32, torch.float16):
                        for y_type in (torch.float32, torch.float16):
                            for out_type in (torch.float32, torch.float16):
                                for inplace in (True, False):
                                    if inplace is True and (y_type is not out_type):
                                        continue
                                    else:
                                        self.axpby(sizea, sizeb, applier, repeat,
                                                   x_type, y_type, out_type, inplace=inplace)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               0, 0, float('nan'), inplace=inplace)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)


if __name__ == '__main__':
    unittest.main()
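Note on the hard-coded reference value in setUp: every y element is filled to yval = 16.0 and each x tensor is built as y*(xval/yval), so every x element is 4.0; with a = 2.0 and b = 8.0, each output element should be 2.0*4.0 + 8.0*16.0 = 136.0, which is exactly what self.ref holds.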