Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit b1098f0

Browse files
committed
Merge pull request #612 from kperelygin/upstream
Defined a new my_memory_system system class for memory.cu
2 parents bb903c0 + 86e24e8 commit b1098f0

File tree

1 file changed

+52
-14
lines changed

1 file changed

+52
-14
lines changed

testing/memory.cu

+52-14
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,44 @@
99
#include <thrust/sequence.h>
1010
#include <thrust/reverse.h>
1111

12+
// Define a new system class, as the my_system one is already used with a thrust::sort template definition
13+
// that calls back into sort.cu
14+
class my_memory_system : public thrust::device_execution_policy<my_memory_system>
15+
{
16+
public:
17+
my_memory_system(int)
18+
: correctly_dispatched(false),
19+
num_copies(0)
20+
{}
21+
22+
my_memory_system(const my_memory_system &other)
23+
: correctly_dispatched(false),
24+
num_copies(other.num_copies + 1)
25+
{}
26+
27+
void validate_dispatch()
28+
{
29+
correctly_dispatched = (num_copies == 0);
30+
}
31+
32+
bool is_valid()
33+
{
34+
return correctly_dispatched;
35+
}
36+
37+
private:
38+
bool correctly_dispatched;
39+
40+
// count the number of copies so that we can validate
41+
// that dispatch does not introduce any
42+
unsigned int num_copies;
43+
44+
45+
// disallow default construction
46+
my_memory_system();
47+
};
48+
49+
1250
template<typename T1, typename T2>
1351
bool are_same(const T1 &, const T2 &)
1452
{
@@ -27,7 +65,7 @@ void TestSelectSystemDifferentTypes()
2765
{
2866
using thrust::system::detail::generic::select_system;
2967

30-
my_system my_sys(0);
68+
my_memory_system my_sys(0);
3169
thrust::device_system_tag device_sys;
3270

3371
// select_system(my_system, device_system_tag) should return device_system_tag (the minimum tag)
@@ -45,7 +83,7 @@ void TestSelectSystemSameTypes()
4583
{
4684
using thrust::system::detail::generic::select_system;
4785

48-
my_system my_sys(0);
86+
my_memory_system my_sys(0);
4987
thrust::device_system_tag device_sys;
5088
thrust::host_system_tag host_sys;
5189

@@ -106,20 +144,20 @@ void TestMalloc()
106144
DECLARE_UNITTEST(TestMalloc);
107145

108146

109-
thrust::pointer<void,my_system>
110-
malloc(my_system &system, std::size_t)
147+
thrust::pointer<void,my_memory_system>
148+
malloc(my_memory_system &system, std::size_t)
111149
{
112150
system.validate_dispatch();
113151

114-
return thrust::pointer<void,my_system>();
152+
return thrust::pointer<void,my_memory_system>();
115153
}
116154

117155

118156
void TestMallocDispatchExplicit()
119157
{
120158
const size_t n = 0;
121159

122-
my_system sys(0);
160+
my_memory_system sys(0);
123161
thrust::malloc(sys, n);
124162

125163
ASSERT_EQUAL(true, sys.is_valid());
@@ -128,17 +166,17 @@ DECLARE_UNITTEST(TestMallocDispatchExplicit);
128166

129167

130168
template<typename Pointer>
131-
void free(my_system &system, Pointer)
169+
void free(my_memory_system &system, Pointer)
132170
{
133171
system.validate_dispatch();
134172
}
135173

136174

137175
void TestFreeDispatchExplicit()
138176
{
139-
thrust::pointer<my_system,void> ptr;
177+
thrust::pointer<my_memory_system,void> ptr;
140178

141-
my_system sys(0);
179+
my_memory_system sys(0);
142180
thrust::free(sys, ptr);
143181

144182
ASSERT_EQUAL(true, sys.is_valid());
@@ -147,14 +185,14 @@ DECLARE_UNITTEST(TestFreeDispatchExplicit);
147185

148186

149187
template<typename T>
150-
thrust::pair<thrust::pointer<T,my_system>, std::ptrdiff_t>
151-
get_temporary_buffer(my_system &system, std::ptrdiff_t n)
188+
thrust::pair<thrust::pointer<T,my_memory_system>, std::ptrdiff_t>
189+
get_temporary_buffer(my_memory_system &system, std::ptrdiff_t n)
152190
{
153191
system.validate_dispatch();
154192

155193
thrust::device_system_tag device_sys;
156194
thrust::pair<thrust::pointer<T, thrust::device_system_tag>, std::ptrdiff_t> result = thrust::get_temporary_buffer<T>(device_sys, n);
157-
return thrust::make_pair(thrust::pointer<T,my_system>(result.first.get()), result.second);
195+
return thrust::make_pair(thrust::pointer<T,my_memory_system>(result.first.get()), result.second);
158196
}
159197

160198

@@ -166,7 +204,7 @@ void TestGetTemporaryBufferDispatchExplicit()
166204
#else
167205
const size_t n = 9001;
168206

169-
my_system sys(0);
207+
my_memory_system sys(0);
170208
typedef thrust::pointer<int, thrust::device_system_tag> pointer;
171209
thrust::pair<pointer, std::ptrdiff_t> ptr_and_sz = thrust::get_temporary_buffer<int>(sys, n);
172210

@@ -205,7 +243,7 @@ void TestGetTemporaryBufferDispatchImplicit()
205243
thrust::reverse(vec.begin(), vec.end());
206244

207245
// call something we know will invoke get_temporary_buffer
208-
my_system sys(0);
246+
my_memory_system sys(0);
209247
thrust::sort(sys, vec.begin(), vec.end());
210248

211249
ASSERT_EQUAL(true, thrust::is_sorted(vec.begin(), vec.end()));

0 commit comments

Comments
 (0)