-
Notifications
You must be signed in to change notification settings - Fork 3
/
advisor.cpp
70 lines (62 loc) · 2.8 KB
/
advisor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include "advisor.h"

#include <memory>
namespace cuFFTAdvisor {
std::vector<BenchmarkResult const *> *Advisor::benchmark(
int device, int x, int y, int z, int n, Tristate::Tristate isBatched,
Tristate::Tristate isFloat, Tristate::Tristate isForward,
Tristate::Tristate isInPlace, Tristate::Tristate isReal) {
Validator::validate(x, y, z, n, device);
std::vector<Transform const *> transforms;
TransformGenerator generator;
generator.generate(device, x, y, z, n, isBatched, isFloat, isForward,
isInPlace, isReal, transforms);
std::vector<BenchmarkResult const *> *result = benchmark(transforms);
return result;
}
/**
 * Resolves the memory limit to use for a device.
 *
 * @param device device whose total memory is queried when no explicit
 *               limit was supplied
 * @param size   explicit limit, or INT_MAX meaning "not specified"
 * @return @p size unchanged when it is an explicit limit; otherwise the
 *         device's total memory converted with toMB() and rounded up
 *         with std::ceil.
 */
int Advisor::getMaxMemory(int device, int size) {
  // Any value other than the INT_MAX sentinel is taken as-is.
  if (size != INT_MAX) {
    return size;
  }
  return std::ceil(toMB(getTotalMemory(device)));
}
/**
 * Suggests up to @p howMany transform configurations close to the
 * requested one, without running any benchmarks.
 *
 * The search is delegated to SizeOptimizer, which is configured for
 * CudaVersion::V_8 and for @p allowTransposition.
 *
 * @param maxMemory memory limit; INT_MAX means "use the device total"
 *                  (resolved through getMaxMemory()).
 * @return heap-allocated vector produced by SizeOptimizer::optimize;
 *         the caller takes ownership of it.
 */
std::vector<Transform const *> *Advisor::recommend(
    int howMany, int device, int x, int y, int z, int n,
    Tristate::Tristate isBatched, Tristate::Tristate isFloat,
    Tristate::Tristate isForward, Tristate::Tristate isInPlace,
    Tristate::Tristate isReal, int maxSignalInc, int maxMemory,
    bool allowTransposition, bool squareOnly, bool crop) {
  // The device is checked on its own first because resolving the memory
  // limit may query it (getMaxMemory -> getTotalMemory).
  Validator::validate(device);
  const int memoryLimit = getMaxMemory(device, maxMemory);
  Validator::validate(x, y, z, n, device, maxSignalInc, memoryLimit,
                      allowTransposition, squareOnly);

  GeneralTransform request(device, x, y, z, n, isBatched, isFloat, isForward,
                           isInPlace, isReal);
  SizeOptimizer optimizer(CudaVersion::V_8, request, allowTransposition);
  return optimizer.optimize(howMany, maxSignalInc, memoryLimit, squareOnly,
                            crop);
}
std::vector<BenchmarkResult const *> *Advisor::find(
int howMany, int device, int x, int y, int z, int n,
Tristate::Tristate isBatched, Tristate::Tristate isFloat,
Tristate::Tristate isForward, Tristate::Tristate isInPlace,
Tristate::Tristate isReal, int maxSignalInc, int maxMemory,
bool allowTransposition, bool squareOnly, bool crop) {
std::vector<Transform const *> *candidates =
recommend(howMany, device, x, y, z, n, isBatched, isFloat, isForward,
isInPlace, isReal, maxSignalInc, maxMemory, allowTransposition,
squareOnly, crop);
std::vector<BenchmarkResult const *> *result = benchmark(*candidates);
std::sort(result->begin(), result->end(), BenchmarkResult::execSort);
delete candidates;
return result;
}
/**
 * Runs Benchmarker::benchmark on each transform, in order.
 *
 * @param transforms transforms to measure; the vector itself is not
 *                   modified.
 * @return heap-allocated vector with one result per input transform, in
 *         the same order; the caller takes ownership of it.
 */
std::vector<BenchmarkResult const *> *Advisor::benchmark(
    std::vector<Transform const *> &transforms) {
  // Hold the container in a unique_ptr until it is complete so it is not
  // leaked if Benchmarker::benchmark throws part-way through.
  // NOTE(review): results already produced before such a throw are still
  // not freed; their ownership semantics are defined outside this file.
  std::unique_ptr<std::vector<BenchmarkResult const *> > results(
      new std::vector<BenchmarkResult const *>());
  results->reserve(transforms.size());
  // Range-for avoids narrowing transforms.size() (size_t) into an int.
  for (Transform const *transform : transforms) {
    results->push_back(Benchmarker::benchmark(transform));
  }
  return results.release();
}
} // namespace cuFFTAdvisor