Skip to content

Commit

Permalink
Metal: more compact evaluate_solution
Browse files Browse the repository at this point in the history
New measurements with half of the runs using non-perceptual mode:
Quality 0: CPU 269.0, GPU 151.4
Quality 1: CPU 48.5, GPU 57.8
Quality 3: CPU 33.6, GPU 44.7
  • Loading branch information
aras-p committed Jan 6, 2021
1 parent ad94b29 commit 7c4d3ad
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 46 deletions.
80 changes: 35 additions & 45 deletions src/shaders/metal/bc7e.metal
Original file line number Diff line number Diff line change
Expand Up @@ -660,26 +660,16 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
color_quad_i actualMaxColor = scale_color(&quantMaxColor, pParams);

const uint32_t N = pParams->m_num_selector_weights;
const uint32_t nc = pParams->m_has_alpha ? 4 : 3;

float total_errf = 0;

float wr = pParams->m_weights[0];
float wg = pParams->m_weights[1];
float wb = pParams->m_weights[2];
float wa = pParams->m_weights[3];

float4 ww = float4(pParams->m_weights);
color_quad_f weightedColors[16];
weightedColors[0] = float4(actualMinColor);
weightedColors[N-1] = float4(actualMaxColor);

for (uint32_t i = 1; i < (N - 1); i++)
{
for (uint32_t j = 0; j < nc; j++)
{
float w = tables->g_bc7_weights[pParams->m_weights_index+i];
weightedColors[i][j] = floor((weightedColors[0][j] * (64.0f - w) + weightedColors[N - 1][j] * w + 32) * (1.0f / 64.0f));
}
float w = tables->g_bc7_weights[pParams->m_weights_index+i];
weightedColors[i] = floor((weightedColors[0] * (64.0f - w) + weightedColors[N - 1] * w + 32) * (1.0f / 64.0f));
}

uchar selectors[16];
Expand All @@ -706,9 +696,9 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
float best_sel0 = best_sel - 1;

float3 d0 = weightedColors[(int)best_sel0].rgb - pp;
float err0 = wr * d0.r * d0.r + wg * d0.g * d0.g + wb * d0.b * d0.b;
float err0 = ww.r * d0.r * d0.r + ww.g * d0.g * d0.g + ww.b * d0.b * d0.b;
float3 d1 = weightedColors[(int)best_sel].rgb - pp;
float err1 = wr * d1.r * d1.r + wg * d1.g * d1.g + wb * d1.b * d1.b;
float err1 = ww.r * d1.r * d1.r + ww.g * d1.g * d1.g + ww.b * d1.b * d1.b;

float min_err = min(err0, err1);
total_errf += min_err;
Expand All @@ -726,13 +716,13 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
int best_sel;
{
float3 d0 = weightedColors[0].rgb - pp;
float err0 = wr * d0.r * d0.r + wg * d0.g * d0.g + wb * d0.b * d0.b;
float err0 = ww.r * d0.r * d0.r + ww.g * d0.g * d0.g + ww.b * d0.b * d0.b;
float3 d1 = weightedColors[1].rgb - pp;
float err1 = wr * d1.r * d1.r + wg * d1.g * d1.g + wb * d1.b * d1.b;
float err1 = ww.r * d1.r * d1.r + ww.g * d1.g * d1.g + ww.b * d1.b * d1.b;
float3 d2 = weightedColors[2].rgb - pp;
float err2 = wr * d2.r * d2.r + wg * d2.g * d2.g + wb * d2.b * d2.b;
float err2 = ww.r * d2.r * d2.r + ww.g * d2.g * d2.g + ww.b * d2.b * d2.b;
float3 d3 = weightedColors[3].rgb - pp;
float err3 = wr * d3.r * d3.r + wg * d3.g * d3.g + wb * d3.b * d3.b;
float err3 = ww.r * d3.r * d3.r + ww.g * d3.g * d3.g + ww.b * d3.b * d3.b;

best_err = min(min(min(err0, err1), err2), err3);
best_sel = select(0, 1, best_err == err1);
Expand All @@ -741,13 +731,13 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
}
{
float3 d0 = weightedColors[4].rgb - pp;
float err0 = wr * d0.r * d0.r + wg * d0.g * d0.g + wb * d0.b * d0.b;
float err0 = ww.r * d0.r * d0.r + ww.g * d0.g * d0.g + ww.b * d0.b * d0.b;
float3 d1 = weightedColors[5].rgb - pp;
float err1 = wr * d1.r * d1.r + wg * d1.g * d1.g + wb * d1.b * d1.b;
float err1 = ww.r * d1.r * d1.r + ww.g * d1.g * d1.g + ww.b * d1.b * d1.b;
float3 d2 = weightedColors[6].rgb - pp;
float err2 = wr * d2.r * d2.r + wg * d2.g * d2.g + wb * d2.b * d2.b;
float err2 = ww.r * d2.r * d2.r + ww.g * d2.g * d2.g + ww.b * d2.b * d2.b;
float3 d3 = weightedColors[7].rgb - pp;
float err3 = wr * d3.r * d3.r + wg * d3.g * d3.g + wb * d3.b * d3.b;
float err3 = ww.r * d3.r * d3.r + ww.g * d3.g * d3.g + ww.b * d3.b * d3.b;

best_err = min(best_err, min(min(min(err0, err1), err2), err3));
best_sel = select(best_sel, 4, best_err == err0);
Expand All @@ -768,13 +758,13 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
float3 pp = float3(pPixels[i].rgb);

float3 d0 = weightedColors[0].rgb - pp;
float err0 = wr * d0.r * d0.r + wg * d0.g * d0.g + wb * d0.b * d0.b;
float err0 = ww.r * d0.r * d0.r + ww.g * d0.g * d0.g + ww.b * d0.b * d0.b;
float3 d1 = weightedColors[1].rgb - pp;
float err1 = wr * d1.r * d1.r + wg * d1.g * d1.g + wb * d1.b * d1.b;
float err1 = ww.r * d1.r * d1.r + ww.g * d1.g * d1.g + ww.b * d1.b * d1.b;
float3 d2 = weightedColors[2].rgb - pp;
float err2 = wr * d2.r * d2.r + wg * d2.g * d2.g + wb * d2.b * d2.b;
float err2 = ww.r * d2.r * d2.r + ww.g * d2.g * d2.g + ww.b * d2.b * d2.b;
float3 d3 = weightedColors[3].rgb - pp;
float err3 = wr * d3.r * d3.r + wg * d3.g * d3.g + wb * d3.b * d3.b;
float err3 = ww.r * d3.r * d3.r + ww.g * d3.g * d3.g + ww.b * d3.b * d3.b;

float best_err = min(min(min(err0, err1), err2), err3);
int best_sel = select(0, 1, best_err == err1);
Expand Down Expand Up @@ -807,9 +797,9 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
float best_sel0 = best_sel - 1;

float4 d0 = weightedColors[(int)best_sel0] - pp;
float err0 = wr * d0.r * d0.r + wg * d0.g * d0.g + wb * d0.b * d0.b + wa * d0.a * d0.a;
float err0 = ww.r * d0.r * d0.r + ww.g * d0.g * d0.g + ww.b * d0.b * d0.b + ww.a * d0.a * d0.a;
float4 d1 = weightedColors[(int)best_sel] - pp;
float err1 = wr * d1.r * d1.r + wg * d1.g * d1.g + wb * d1.b * d1.b + wa * d1.a * d1.a;
float err1 = ww.r * d1.r * d1.r + ww.g * d1.g * d1.g + ww.b * d1.b * d1.b + ww.a * d1.a * d1.a;

float min_err = min(err0, err1);
total_errf += min_err;
Expand All @@ -827,13 +817,13 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
int best_sel;
{
float4 d0 = weightedColors[0] - pp;
float err0 = wr * d0.r * d0.r + wg * d0.g * d0.g + wb * d0.b * d0.b + wa * d0.a * d0.a;
float err0 = ww.r * d0.r * d0.r + ww.g * d0.g * d0.g + ww.b * d0.b * d0.b + ww.a * d0.a * d0.a;
float4 d1 = weightedColors[1] - pp;
float err1 = wr * d1.r * d1.r + wg * d1.g * d1.g + wb * d1.b * d1.b + wa * d1.a * d1.a;
float err1 = ww.r * d1.r * d1.r + ww.g * d1.g * d1.g + ww.b * d1.b * d1.b + ww.a * d1.a * d1.a;
float4 d2 = weightedColors[2] - pp;
float err2 = wr * d2.r * d2.r + wg * d2.g * d2.g + wb * d2.b * d2.b + wa * d2.a * d2.a;
float err2 = ww.r * d2.r * d2.r + ww.g * d2.g * d2.g + ww.b * d2.b * d2.b + ww.a * d2.a * d2.a;
float4 d3 = weightedColors[3] - pp;
float err3 = wr * d3.r * d3.r + wg * d3.g * d3.g + wb * d3.b * d3.b + wa * d3.a * d3.a;
float err3 = ww.r * d3.r * d3.r + ww.g * d3.g * d3.g + ww.b * d3.b * d3.b + ww.a * d3.a * d3.a;

best_err = min(min(min(err0, err1), err2), err3);
best_sel = select(0, 1, best_err == err1);
Expand All @@ -842,13 +832,13 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
}
{
float4 d0 = weightedColors[4] - pp;
float err0 = wr * d0.r * d0.r + wg * d0.g * d0.g + wb * d0.b * d0.b + wa * d0.a * d0.a;
float err0 = ww.r * d0.r * d0.r + ww.g * d0.g * d0.g + ww.b * d0.b * d0.b + ww.a * d0.a * d0.a;
float4 d1 = weightedColors[5] - pp;
float err1 = wr * d1.r * d1.r + wg * d1.g * d1.g + wb * d1.b * d1.b + wa * d1.a * d1.a;
float err1 = ww.r * d1.r * d1.r + ww.g * d1.g * d1.g + ww.b * d1.b * d1.b + ww.a * d1.a * d1.a;
float4 d2 = weightedColors[6] - pp;
float err2 = wr * d2.r * d2.r + wg * d2.g * d2.g + wb * d2.b * d2.b + wa * d2.a * d2.a;
float err2 = ww.r * d2.r * d2.r + ww.g * d2.g * d2.g + ww.b * d2.b * d2.b + ww.a * d2.a * d2.a;
float4 d3 = weightedColors[7] - pp;
float err3 = wr * d3.r * d3.r + wg * d3.g * d3.g + wb * d3.b * d3.b + wa * d3.a * d3.a;
float err3 = ww.r * d3.r * d3.r + ww.g * d3.g * d3.g + ww.b * d3.b * d3.b + ww.a * d3.a * d3.a;

best_err = min(best_err, min(min(min(err0, err1), err2), err3));
best_sel = select(best_sel, 4, best_err == err0);
Expand All @@ -869,13 +859,13 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
float4 pp = float4(pPixels[i]);

float4 d0 = weightedColors[0] - pp;
float err0 = wr * d0.r * d0.r + wg * d0.g * d0.g + wb * d0.b * d0.b + wa * d0.a * d0.a;
float err0 = ww.r * d0.r * d0.r + ww.g * d0.g * d0.g + ww.b * d0.b * d0.b + ww.a * d0.a * d0.a;
float4 d1 = weightedColors[1] - pp;
float err1 = wr * d1.r * d1.r + wg * d1.g * d1.g + wb * d1.b * d1.b + wa * d1.a * d1.a;
float err1 = ww.r * d1.r * d1.r + ww.g * d1.g * d1.g + ww.b * d1.b * d1.b + ww.a * d1.a * d1.a;
float4 d2 = weightedColors[2] - pp;
float err2 = wr * d2.r * d2.r + wg * d2.g * d2.g + wb * d2.b * d2.b + wa * d2.a * d2.a;
float err2 = ww.r * d2.r * d2.r + ww.g * d2.g * d2.g + ww.b * d2.b * d2.b + ww.a * d2.a * d2.a;
float4 d3 = weightedColors[3] - pp;
float err3 = wr * d3.r * d3.r + wg * d3.g * d3.g + wb * d3.b * d3.b + wa * d3.a * d3.a;
float err3 = ww.r * d3.r * d3.r + ww.g * d3.g * d3.g + ww.b * d3.b * d3.b + ww.a * d3.a * d3.a;

float best_err = min(min(min(err0, err1), err2), err3);
int best_sel = select(0, 1, best_err == err1);
Expand All @@ -891,8 +881,8 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
else
{
// perceptual
wg *= pr_weight;
wb *= pb_weight;
ww.g *= pr_weight;
ww.b *= pb_weight;

float3 weightedColorsYCrCb[16];
for (uint32_t i = 0; i < N; i++)
Expand All @@ -918,7 +908,7 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
{
float3 d = ycrcb - weightedColorsYCrCb[j];
float da = pp.a - weightedColors[j].a;
float err = (wr * d.x * d.x) + (wg * d.y * d.y) + (wb * d.z * d.z) + (wa * da * da);
float err = (ww.r * d.x * d.x) + (ww.g * d.y * d.y) + (ww.b * d.z * d.z) + (ww.a * da * da);
if (err < best_err)
{
best_err = err;
Expand All @@ -944,7 +934,7 @@ static uint32_t evaluate_solution(const uchar4 pLow, const uchar4 pHigh, const t
for (uint32_t j = 0; j < N; j++)
{
float3 d = ycrcb - weightedColorsYCrCb[j];
float err = (wr * d.x * d.x) + (wg * d.y * d.y) + (wb * d.z * d.z);
float err = (ww.r * d.x * d.x) + (ww.g * d.y * d.y) + (ww.b * d.z * d.z);
if (err < best_err)
{
best_err = err;
Expand Down
2 changes: 1 addition & 1 deletion src/testmain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#endif

const bool kDoCapture = false;
const int kQuality = 0;
const int kQuality = 3;
const int kRunCount = kDoCapture ? 1 : 8;
const bool kAlwaysPerceptual = false;
#ifdef _MSC_VER
Expand Down

0 comments on commit 7c4d3ad

Please sign in to comment.