forked from torch/cutorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinit.lua
34 lines (27 loc) · 983 Bytes
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
-- Bootstrap: load core torch, then the compiled CUDA backend.
require "torch"
-- NOTE(review): deliberately global (no `local`) so `cutorch` is reachable
-- everywhere after this file runs; also returned at the bottom for require().
cutorch = paths.require("libcutorch")
-- Make CUDA storages/tensors print the same way as their float counterparts.
torch.CudaStorage.__tostring__ = torch.FloatStorage.__tostring__
torch.CudaTensor.__tostring__ = torch.FloatTensor.__tostring__
include('Tensor.lua')
include('FFI.lua')
include('test.lua')
-- Lua 5.1 exposes global `unpack`; 5.2+ moved it to `table.unpack`.
local unpack = unpack or table.unpack
--- Run `closure` with `newDeviceID` selected as the current CUDA device,
-- restoring the previously selected device afterwards -- even if the
-- closure raises an error (the error is re-raised after restoration).
-- @param newDeviceID device to select while the closure runs
-- @param closure zero-argument function to execute
-- @return all return values of `closure`
function cutorch.withDevice(newDeviceID, closure)
   local curDeviceID = cutorch.getDevice()
   cutorch.setDevice(newDeviceID)
   -- Capture the true result count alongside the results. A bare
   -- {pcall(closure)} loses trailing nil returns, because `#` is
   -- unspecified on tables with holes and unpack(vals, 2) relies on it.
   local n, vals =
      (function(...) return select('#', ...), {...} end)(pcall(closure))
   -- Restore the caller's device before either returning or re-raising.
   cutorch.setDevice(curDeviceID)
   if vals[1] then
      -- Success: forward every result, including trailing nils.
      return unpack(vals, 2, n)
   end
   -- Failure: vals[2] is the error value captured by pcall; re-raise it.
   error(unpack(vals, 2, n))
end
--- Create a FloatTensor backed by the CudaHostAllocator
-- (presumably page-locked host memory -- confirm in libcutorch).
-- The size may be given as a single LongStorage or as a sequence of numbers.
function cutorch.createCudaHostTensor(...)
   local shape
   if torch.isStorage(...) then
      -- First (and only meaningful) vararg is a storage describing the size.
      local sizeStorage = ...
      shape = torch.LongTensor(sizeStorage)
   else
      -- Treat the varargs themselves as the dimension sizes.
      shape = torch.LongTensor({...})
   end
   local elementCount = shape:prod()
   local hostStorage =
      torch.FloatStorage(cutorch.CudaHostAllocator, elementCount)
   return torch.FloatTensor(hostStorage, 1, shape:storage())
end
-- NOTE(review): presumably makes cutorch report its allocations to the Lua
-- GC heap accounting -- confirm semantics in the libcutorch C sources.
cutorch.setHeapTracking(true)
-- Returned table is the same object assigned to the global `cutorch` above.
return cutorch