-
Notifications
You must be signed in to change notification settings - Fork 304
/
Copy pathmodel.h
111 lines (94 loc) · 2.98 KB
/
model.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//-----------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//-----------------------------------------------------------------------------
#pragma once
namespace pydml
{
// Holds the compiled operator obtained by compiling a dml::Graph for a given
// set of output expressions.
struct CompiledModel
{
    // Result of graph.Compile(); assigned once, at construction.
    Microsoft::WRL::ComPtr<IDMLCompiledOperator> op;

    // Compiles `graph` with the execution `flags` for the requested `outputs`
    // and keeps the resulting operator in `op`.
    CompiledModel(dml::Graph& graph, DML_EXECUTION_FLAGS flags, std::vector<dml::Expression>& outputs)
        : op(graph.Compile(flags, outputs))
    {
    }
};
// Owns a host-side byte buffer for one tensor together with the layout
// metadata (item size, format string, rank, shape, byte strides) needed to
// describe that buffer through the Python buffer protocol.
struct TensorData
{
    // Builds a TensorData by copying the bytes of a pybind11 buffer.
    //
    // Shape and strides are taken verbatim from `info`. pybind11 reports
    // buffer_info::strides in bytes already, so they must NOT be scaled by the
    // item size again.
    TensorData(py::buffer_info const& info) :
        itemSize(info.itemsize),
        format(info.format),
        dimensions(info.ndim),
        shape(info.shape),
        strides(info.strides)
    {
        const size_t sizeInBytes = Size();
        buffer.resize(sizeInBytes);

        // The copy is a single flat block of Size() bytes starting at info.ptr.
        // NOTE(review): this assumes the source buffer is densely packed;
        // non-contiguous buffers are not handled here.
        // Skip the call for empty tensors: either pointer may be null then.
        if (sizeInBytes != 0)
        {
            memcpy(buffer.data(), info.ptr, sizeInBytes);
        }
    }

    // Builds a zero-filled TensorData sized for the tensor described by `desc`.
    //
    // `desc` must be non-null. The element type is always described as float.
    // NOTE(review): the data type carried by `desc` is not consulted - confirm
    // that only float tensors reach this constructor.
    TensorData(dml::TensorDesc* desc) :
        itemSize(sizeof(float)),
        format(py::format_descriptor<float>::format()),
        dimensions(desc->sizes.size())
    {
        shape.reserve(desc->sizes.size());
        for (auto size : desc->sizes)
        {
            shape.push_back(static_cast<ssize_t>(size));
        }

        if (desc->strides)
        {
            strides.reserve(desc->strides->size());
            for (auto stride : *desc->strides)
            {
                strides.push_back(static_cast<ssize_t>(stride));
            }
        }
        else
        {
            // Use default descending packed strides (in elements): the last
            // dimension has stride 1, each earlier one is the product of the
            // extents that follow it.
            strides.resize(shape.size());
            ssize_t stride = 1;
            for (size_t i = strides.size(); i-- > 0; )
            {
                strides[i] = stride;
                stride *= shape[i];
            }
        }

        // The strides gathered above are element counts; numpy strides use
        // bytes, so scale them by the item size.
        const ssize_t bytesPerItem = static_cast<ssize_t>(itemSize);
        for (auto& stride : strides)
        {
            stride *= bytesPerItem;
        }

        buffer.resize(static_cast<size_t>(desc->totalTensorSizeInBytes));
    }

    // Creates an empty TensorData: no bytes, rank 0, item size 0.
    TensorData() = default;

    // Returns a mutable pointer to the first byte of the owned buffer.
    // The const_cast lets const holders hand the storage to APIs that take void*.
    void* Get() const { return static_cast<void*>(const_cast<byte*>(buffer.data())); }

    // Returns the packed size in bytes implied by `shape` and `itemSize`
    // (product of all extents times the item size). An empty shape yields
    // `itemSize`; a default-constructed object yields 0.
    size_t Size() const
    {
        size_t size = 1;
        for (auto length : shape)
        {
            size *= static_cast<size_t>(length);
        }
        return size * itemSize;
    }

    std::vector<byte> buffer;       // Owned tensor bytes.
    size_t itemSize = 0;            // Size of one element in bytes.
    std::string format;             // Element format string (as used by py::buffer_info).
    size_t dimensions = 0;          // Number of dimensions.
    std::vector<ssize_t> shape;     // Extent of each dimension, in elements.
    std::vector<ssize_t> strides;   // Stride of each dimension, in bytes.
};
// Associates the tensor description of an expression's output with a host-side
// copy of the data supplied for it.
struct Binding
{
    dml::TensorDesc desc;   // Output description taken from the expression.
    TensorData data;        // Copy of the caller-provided buffer contents.

    // Creates an empty binding with default-constructed members.
    Binding() = default;

    // Captures `expression`'s output description and copies the buffer
    // described by `info` into `data`.
    explicit Binding(dml::Expression& expression, const py::buffer_info& info)
        : desc(expression.GetOutputDesc())
        , data(info)
    {
    }
};
}