Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Split into low and high level wrappers.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed May 2, 2019
1 parent dcc92d3 commit 602d13f
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 101 deletions.
2 changes: 1 addition & 1 deletion src/tensor/CUTENSOR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module CUTENSOR

import CUDAapi

import CUDAdrv: CUDAdrv, CuContext, CuStream_t, CuPtr, PtrOrCuPtr, CU_NULL
import CUDAdrv: CUDAdrv, CuContext

using ..CuArrays
using ..CuArrays: libcutensor, active_context
Expand Down
117 changes: 108 additions & 9 deletions src/tensor/libcutensor.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,121 @@
# Julia wrapper for header: /usr/local/cuda/include/cusparse.h
# low-level wrappers of the CUTENSOR library

using CUDAdrv: CuStream_t, CuPtr, PtrOrCuPtr, CU_NULL

cutensorGetErrorString(status) = ccall((:cutensorGetErrorString,libcutensor), Ptr{UInt8},
(cutensorStatus_t,), status)

#helper functions
function cutensorCreate()
handle = Ref{cutensorHandle_t}()
@check ccall((:cutensorCreate, libcutensor), cutensorStatus_t, (Ptr{cutensorHandle_t},), handle)
@check ccall((:cutensorCreate, libcutensor), cutensorStatus_t,
(Ptr{cutensorHandle_t},), handle)
handle[]
end

function cutensorDestroy(handle)
@check ccall((:cutensorDestroy, libcutensor), cutensorStatus_t, (cutensorHandle_t,), handle)
end
function cutensorGetVersion(handle, version)
@check ccall((:cutensorGetVersion, libcutensor), cutensorStatus_t, (cutensorHandle_t, Ptr{Cint}), handle, version)
end
function cutensorCreateTensorDescriptor(numModes::Cint, extent::Vector{Int64}, stride::Vector{Int64}, T::cudaDataType_t, unaryOp::cutensorOperator_t, vectorWidth::Cint, vectorModeIndex::Cint)

function cutensorCreateTensorDescriptor(numModes::Cint,
extent::Vector{Int64},
stride::Vector{Int64},
T::cudaDataType_t,
unaryOp::cutensorOperator_t,
vectorWidth::Cint,
vectorModeIndex::Cint)
desc = Ref{cutensorTensorDescriptor_t}(C_NULL)
@check ccall((:cutensorCreateTensorDescriptor, libcutensor), cutensorStatus_t, (Ref{cutensorTensorDescriptor_t}, Cint, Ptr{Int64}, Ptr{Int64}, cudaDataType_t, cutensorOperator_t, Cint, Cint), desc, numModes, extent, stride, T, unaryOp, vectorWidth, vectorModeIndex)
@check ccall((:cutensorCreateTensorDescriptor, libcutensor), cutensorStatus_t,
(Ref{cutensorTensorDescriptor_t}, Cint, Ptr{Int64}, Ptr{Int64},
cudaDataType_t, cutensorOperator_t, Cint, Cint),
desc, numModes, extent, stride, T, unaryOp, vectorWidth, vectorModeIndex)
return desc[]
end

function cutensorDestroyTensorDescriptor(desc::cutensorTensorDescriptor_t)
@check ccall((:cutensorDestroyTensorDescriptor, libcutensor), cutensorStatus_t, (cutensorTensorDescriptor_t,), desc)
@check ccall((:cutensorDestroyTensorDescriptor, libcutensor), cutensorStatus_t,
(cutensorTensorDescriptor_t,), desc)
end

function cutensorElementwiseTrinary(handle,
alpha, A, descA, modeA,
beta, B, descB, modeB,
gamma, C, descC, modeC,
D, descD, modeD,
opAB, opABC, typeCompute, stream)
@check ccall((:cutensorElementwiseTrinary,libcutensor), cutensorStatus_t,
(cutensorHandle_t, Ptr{Cvoid}, CuPtr{Cvoid}, cutensorTensorDescriptor_t,
Ptr{Cint}, Ptr{Cvoid}, CuPtr{Cvoid}, cutensorTensorDescriptor_t,
Ptr{Cint}, Ptr{Cvoid}, CuPtr{Cvoid}, cutensorTensorDescriptor_t,
Ptr{Cint}, CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
cutensorOperator_t, cutensorOperator_t, cudaDataType_t, CuStream_t),
handle, alpha, A, descA, modeA, [beta], B, descB, modeB, gamma, C, descC,
modeC, D, descD, modeD, opAB, opABC, typeCompute, stream)
end

function cutensorElementwiseBinary(handle,
alpha, A, descA, modeA,
gamma, C, descC, modeC,
D, descD, modeD,
opAC, typeCompute, stream)
@check ccall((:cutensorElementwiseBinary,libcutensor), cutensorStatus_t,
(cutensorHandle_t, Ptr{Cvoid}, CuPtr{Cvoid}, cutensorTensorDescriptor_t,
Ptr{Cint}, Ptr{Cvoid}, CuPtr{Cvoid}, cutensorTensorDescriptor_t,
Ptr{Cint}, CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
cutensorOperator_t, cudaDataType_t, CuStream_t),
handle, alpha, A, descA, modeA, gamma, C, descC, modeC, D, descD, modeD,
opAC, typeCompute, stream)
end

function cutensorPermutation(handle,
alpha, A, descA, modeA,
B, descB, modeB,
typeCompute, stream)
@check ccall((:cutensorPermutation,libcutensor), cutensorStatus_t,
(cutensorHandle_t, Ptr{Cvoid}, CuPtr{Cvoid}, cutensorTensorDescriptor_t,
Ptr{Cint}, CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
cudaDataType_t, CuStream_t),
handle, alpha, A, descA, modeA, B, descB, modeB, typeCompute, stream)
end

function cutensorContraction(handle,
alpha, A, descA, modeA,
B, descB, modeB,
beta, C, descC, modeC,
D, descD, modeD,
opOut, typeCompute, algo, workspace, workspaceSize, stream)
@check ccall((:cutensorContraction,libcutensor), cutensorStatus_t,
(cutensorHandle_t, Ptr{Cvoid}, CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
Ptr{Cvoid}, CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
cutensorOperator_t, cudaDataType_t, cutensorAlgo_t, CuPtr{Cvoid},
UInt64, CuStream_t),
handle, alpha, A, descA, modeA, B, descB, modeB, beta, C, descC,
modeC, D, descD, modeD, opOut, typeCompute, algo, workspace, workspaceSize,
stream)
end

function cutensorContractionGetWorkspace(handle,
A, descA, modeA,
B, descB, modeB,
C, descC, modeC,
D, descD, modeD,
opOut, typeCompute, algo, pref, workspaceSize)
@check ccall((:cutensorContractionGetWorkspace,libcutensor), cutensorStatus_t,
(cutensorHandle_t, CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
CuPtr{Cvoid}, cutensorTensorDescriptor_t, Ptr{Cint},
cutensorOperator_t, cudaDataType_t, cutensorAlgo_t, cutensorWorksizePreference_t,
Ptr{UInt64}),
handle, A, descA, modeA, B, descB, modeB, C, descC, modeC,
D, descD, modeD, opOut, typeCompute, algo, pref, workspaceSize)
end

function cutensorContractionMaxAlgos()
max_algos = Ref{Cint}()
@check ccall((:cutensorContractionMaxAlgos,libcutensor), cutensorStatus_t,
(Ptr{Cint},), max_algos)
return max_algos
end

Loading

0 comments on commit 602d13f

Please sign in to comment.