Multi-tensor axpby kernel for more flexible unscaling (groundwork for p…
definitelynotmcarilli committed Mar 19, 2019
1 parent 74c06d8 commit 5e55200
Showing 6 changed files with 270 additions and 3 deletions.
9 changes: 9 additions & 0 deletions csrc/amp_C_frontend.cpp
@@ -6,7 +6,16 @@ void multi_tensor_scale_cuda(
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float scale);

void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
}
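
For orientation: from Python, the new binding is driven through the same MultiTensorApply helper as multi_tensor_scale. A minimal usage sketch (tensor sizes, values, and the chunk size are illustrative; the call pattern mirrors the test added in this commit):

import torch
import amp_C
from apex.multi_tensor_apply import MultiTensorApply

applier = MultiTensorApply(2048*32)              # chunk size per launch
overflow_buf = torch.cuda.IntTensor(1).zero_()   # shared noop/overflow flag

# Three parallel lists: x, y, and out tensors (contiguous, on the GPU).
x_list = [torch.cuda.HalfTensor(4096).fill_(4.0) for _ in range(3)]
y_list = [torch.cuda.FloatTensor(4096).fill_(16.0) for _ in range(3)]
out_list = [torch.empty_like(y) for y in y_list]

# out = a*x + b*y, elementwise, for every (x, y, out) triple in the lists.
applier(amp_C.multi_tensor_axpby, overflow_buf, [x_list, y_list, out_list], 2.0, 8.0)
# Each out tensor now holds 2*4 + 8*16 = 136, and overflow_buf stays 0.
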
110 changes: 110 additions & 0 deletions csrc/multi_tensor_axpby_kernel.cu
@@ -0,0 +1,110 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename x_t, typename y_t, typename out_t>
struct AxpbyFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<3>& tl,
    float a,
    float b)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx*chunk_size;

    y_t* y = (y_t*)tl.addresses[1][tensor_loc];
    y += chunk_idx*chunk_size;

    out_t* out = (out_t*)tl.addresses[2][tensor_loc];
    out += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // Non-divergent exit condition for __syncthreads, not necessary here
    float xs[ILP];
    float ys[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        xs[ii] = 0;
        ys[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          xs[ii] = static_cast<float>(x[i]);
          ys[ii] = static_cast<float>(y[i]);
        }
      }

      // see note in multi_tensor_scale_kernel.cu
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
          if(isfinite(xs[ii]) && isfinite(ys[ii]))
            out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
          else
          {
            out[i] = static_cast<out_t>(a*xs[ii] + b*ys[ii]);
            *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
          }
      }
    }
  }
};

void multi_tensor_axpby_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float a,
  float b)
{
  using namespace at;
  // The output (downscaled) type is always float.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
        multi_tensor_apply<3>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AxpbyFunctor<scalar_t_0, scalar_t_1, scalar_t_2>(),
          a,
          b); )))

  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
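
Because the functor writes the non-finite result and raises the flag rather than returning early, a Python caller detects overflow after the call by reading the flag. A rough sketch under that assumption (tensor sizes and values are hypothetical; the pattern follows the commented-out find_inf test further down):

import torch
import amp_C
from apex.multi_tensor_apply import MultiTensorApply

applier = MultiTensorApply(2048*32)
overflow_buf = torch.cuda.IntTensor(1).zero_()

x = [torch.cuda.HalfTensor(1024).fill_(1.0)]
y = [torch.cuda.FloatTensor(1024).fill_(1.0)]
out = [torch.cuda.FloatTensor(1024).zero_()]
x[0][0] = float('inf')   # poison one element

applier(amp_C.multi_tensor_axpby, overflow_buf, [x, y, out], 2.0, 8.0)

# The inf propagates into out, and the racy flag writes still land:
# any nonzero value means "an overflow/nan was seen somewhere".
assert overflow_buf.item() != 0
assert torch.isinf(out[0][0]).item()
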
19 changes: 19 additions & 0 deletions csrc/type_shim.h
@@ -12,3 +12,22 @@ struct TypeShim
  // Enable dispatch switch statements to take *this directly for post-3aeb78
  operator at::ScalarType(){ return payload.scalarType(); };
};

#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Float: \
    { \
      using scalar_t_##LEVEL = float; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Half: \
    { \
      using scalar_t_##LEVEL = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
4 changes: 3 additions & 1 deletion setup.py
@@ -48,10 +48,12 @@
ext_modules.append(
    CUDAExtension(name='amp_C',
                  sources=['csrc/amp_C_frontend.cpp',
                           'csrc/multi_tensor_scale_kernel.cu'],
                           'csrc/multi_tensor_scale_kernel.cu',
                           'csrc/multi_tensor_axpby_kernel.cu'],
                  extra_compile_args={'cxx': ['-O3'],
                                      'nvcc':['-lineinfo',
                                              '-O3',
                                              # '--resource-usage',
                                              '--use_fast_math']}))
ext_modules.append(
    CUDAExtension(name='fused_adam_cuda',
128 changes: 128 additions & 0 deletions tests/L0/run_amp/test_multi_tensor_axpby.py
@@ -0,0 +1,128 @@
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    from amp_C import multi_tensor_axpby
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
    disabled = True


class TestMultiTensorAxpby(unittest.TestCase):

    def setUp(self):
        common_init(self)

        self.a = 2.0
        self.b = 8.0
        self.xval = 4.0
        self.yval = 16.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        self.ref = torch.cuda.FloatTensor([136.0])

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def axpby(self, sizea, sizeb, applier, repeat_tensors,
              x_type, y_type, out_type, inplace=False):
        self.overflow_buf.zero_()
        t1 = torch.cuda.FloatTensor(sizea).fill_(1.0)
        t2 = torch.cuda.FloatTensor(sizeb).fill_(1.0)

        y_list = []
        for i in range(repeat_tensors):
            y_list += [t1.clone().to(y_type)*self.yval, t2.clone().to(y_type)*self.yval]

        x_list = [x.clone().to(x_type)*(self.xval/self.yval) for x in y_list]

        if inplace:
            out_list = y_list
        else:
            out_list = [out.clone().to(out_type)*3.0 for out in y_list]

        applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b)

        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
                        msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                          x_type, y_type, out_type, inplace))
        self.assertTrue(self.overflow_buf.item() == 0,
                        msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                          x_type, y_type, out_type, inplace))

    # def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
    #     self.overflow_buf.zero_()
    #     a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
    #     b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

    #     out_list = []
    #     for i in range(repeat_tensors):
    #         out_list += [a.clone().to(out_type), b.clone().to(out_type)]

    #     if inplace:
    #         in_list = out_list
    #     else:
    #         in_list = [out.clone().to(in_type) for out in out_list]

    #     applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

    #     self.overflow_buf.zero_()
    #     in_list[t][ind] = val
    #     applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
    #     self.assertTrue(self.overflow_buf.item())

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for x_type in (torch.float32, torch.float16):
                        for y_type in (torch.float32, torch.float16):
                            for out_type in (torch.float32, torch.float16):
                                for inplace in (True, False):
                                    if inplace is True and (y_type is not out_type):
                                        continue
                                    else:
                                        self.axpby(sizea, sizeb, applier, repeat,
                                                   x_type, y_type, out_type, inplace=inplace)
                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                    #               0, 0, float('nan'), inplace=inplace)
                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                    #               2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                                    # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                    #               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)



if __name__ == '__main__':
    unittest.main()
3 changes: 1 addition & 2 deletions tests/L0/run_amp/test_multi_tensor_scale.py
@@ -24,12 +24,11 @@
class TestMultiTensorScale(unittest.TestCase):

    def setUp(self):
        common_init(self)
        self.scale = 4.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        self.ref = torch.cuda.FloatTensor([1.0])

        common_init(self)

    def tearDown(self):
        pass
