9
9
#include < thrust/sequence.h>
10
10
#include < thrust/reverse.h>
11
11
12
+ // Define a new system class, as the my_system one is already used with a thrust::sort template definition
13
+ // that calls back into sort.cu
14
+ class my_memory_system : public thrust ::device_execution_policy<my_memory_system>
15
+ {
16
+ public:
17
+ my_memory_system (int )
18
+ : correctly_dispatched(false ),
19
+ num_copies (0 )
20
+ {}
21
+
22
+ my_memory_system (const my_memory_system &other)
23
+ : correctly_dispatched(false ),
24
+ num_copies(other.num_copies + 1 )
25
+ {}
26
+
27
+ void validate_dispatch ()
28
+ {
29
+ correctly_dispatched = (num_copies == 0 );
30
+ }
31
+
32
+ bool is_valid ()
33
+ {
34
+ return correctly_dispatched;
35
+ }
36
+
37
+ private:
38
+ bool correctly_dispatched;
39
+
40
+ // count the number of copies so that we can validate
41
+ // that dispatch does not introduce any
42
+ unsigned int num_copies;
43
+
44
+
45
+ // disallow default construction
46
+ my_memory_system ();
47
+ };
48
+
49
+
12
50
template <typename T1, typename T2>
13
51
bool are_same (const T1 &, const T2 &)
14
52
{
@@ -27,7 +65,7 @@ void TestSelectSystemDifferentTypes()
27
65
{
28
66
using thrust::system ::detail::generic::select_system;
29
67
30
- my_system my_sys (0 );
68
+ my_memory_system my_sys (0 );
31
69
thrust::device_system_tag device_sys;
32
70
33
71
// select_system(my_system, device_system_tag) should return device_system_tag (the minimum tag)
@@ -45,7 +83,7 @@ void TestSelectSystemSameTypes()
45
83
{
46
84
using thrust::system ::detail::generic::select_system;
47
85
48
- my_system my_sys (0 );
86
+ my_memory_system my_sys (0 );
49
87
thrust::device_system_tag device_sys;
50
88
thrust::host_system_tag host_sys;
51
89
@@ -106,20 +144,20 @@ void TestMalloc()
106
144
DECLARE_UNITTEST (TestMalloc);
107
145
108
146
109
- thrust::pointer<void ,my_system >
110
- malloc (my_system &system, std::size_t )
147
+ thrust::pointer<void ,my_memory_system >
148
+ malloc (my_memory_system &system, std::size_t )
111
149
{
112
150
system .validate_dispatch ();
113
151
114
- return thrust::pointer<void ,my_system >();
152
+ return thrust::pointer<void ,my_memory_system >();
115
153
}
116
154
117
155
118
156
void TestMallocDispatchExplicit ()
119
157
{
120
158
const size_t n = 0 ;
121
159
122
- my_system sys (0 );
160
+ my_memory_system sys (0 );
123
161
thrust::malloc (sys, n);
124
162
125
163
ASSERT_EQUAL (true , sys.is_valid ());
@@ -128,17 +166,17 @@ DECLARE_UNITTEST(TestMallocDispatchExplicit);
128
166
129
167
130
168
template <typename Pointer>
131
- void free (my_system &system, Pointer)
169
+ void free (my_memory_system &system, Pointer)
132
170
{
133
171
system .validate_dispatch ();
134
172
}
135
173
136
174
137
175
void TestFreeDispatchExplicit ()
138
176
{
139
- thrust::pointer<my_system ,void > ptr;
177
+ thrust::pointer<my_memory_system ,void > ptr;
140
178
141
- my_system sys (0 );
179
+ my_memory_system sys (0 );
142
180
thrust::free (sys, ptr);
143
181
144
182
ASSERT_EQUAL (true , sys.is_valid ());
@@ -147,14 +185,14 @@ DECLARE_UNITTEST(TestFreeDispatchExplicit);
147
185
148
186
149
187
template <typename T>
150
- thrust::pair<thrust::pointer<T,my_system >, std::ptrdiff_t >
151
- get_temporary_buffer (my_system &system, std::ptrdiff_t n)
188
+ thrust::pair<thrust::pointer<T,my_memory_system >, std::ptrdiff_t >
189
+ get_temporary_buffer (my_memory_system &system, std::ptrdiff_t n)
152
190
{
153
191
system .validate_dispatch ();
154
192
155
193
thrust::device_system_tag device_sys;
156
194
thrust::pair<thrust::pointer<T, thrust::device_system_tag>, std::ptrdiff_t > result = thrust::get_temporary_buffer<T>(device_sys, n);
157
- return thrust::make_pair (thrust::pointer<T,my_system >(result.first .get ()), result.second );
195
+ return thrust::make_pair (thrust::pointer<T,my_memory_system >(result.first .get ()), result.second );
158
196
}
159
197
160
198
@@ -166,7 +204,7 @@ void TestGetTemporaryBufferDispatchExplicit()
166
204
#else
167
205
const size_t n = 9001 ;
168
206
169
- my_system sys (0 );
207
+ my_memory_system sys (0 );
170
208
typedef thrust::pointer<int , thrust::device_system_tag> pointer;
171
209
thrust::pair<pointer, std::ptrdiff_t > ptr_and_sz = thrust::get_temporary_buffer<int >(sys, n);
172
210
@@ -205,7 +243,7 @@ void TestGetTemporaryBufferDispatchImplicit()
205
243
thrust::reverse (vec.begin (), vec.end ());
206
244
207
245
// call something we know will invoke get_temporary_buffer
208
- my_system sys (0 );
246
+ my_memory_system sys (0 );
209
247
thrust::sort (sys, vec.begin (), vec.end ());
210
248
211
249
ASSERT_EQUAL (true , thrust::is_sorted (vec.begin (), vec.end ()));
0 commit comments