Skip to content

Commit

Permalink
Fix RS decode when c != 1
Browse files Browse the repository at this point in the history
Fixes #215
  • Loading branch information
mhostetter committed Jan 7, 2022
1 parent b8c4095 commit 69bae7e
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions galois/_codes/_reed_solomon.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def decode(self, codeword, errors=False):
syndrome = codeword.view(self.field) @ self.H[:,-ns:].T

if self.field.ufunc_mode != "python-calculate":
dec_codeword = self._decode_jit(codeword.astype(np.int64), syndrome.astype(np.int64), self.t, int(self.field.primitive_element), self._add_jit, self._subtract_jit, self._multiply_jit, self._reciprocal_jit, self._power_jit, self._berlekamp_massey_jit, self._poly_roots_jit, self._poly_eval_jit, self._convolve_jit, self.field.characteristic, self.field.degree, self.field._irreducible_poly_int)
dec_codeword = self._decode_jit(codeword.astype(np.int64), syndrome.astype(np.int64), self.c, self.t, int(self.field.primitive_element), self._add_jit, self._subtract_jit, self._multiply_jit, self._reciprocal_jit, self._power_jit, self._berlekamp_massey_jit, self._poly_roots_jit, self._poly_eval_jit, self._convolve_jit, self.field.characteristic, self.field.degree, self.field._irreducible_poly_int)
N_errors = dec_codeword[:, -1]

if self.systematic:
Expand Down Expand Up @@ -615,9 +615,9 @@ def is_narrow_sense(self):
# JIT-compiled implementation of the specified functions
###############################################################################

DECODE_CALCULATE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64, FieldClass._BINARY_CALCULATE_SIG, FieldClass._BINARY_CALCULATE_SIG, FieldClass._BINARY_CALCULATE_SIG, FieldClass._UNARY_CALCULATE_SIG, FieldClass._BINARY_CALCULATE_SIG, _lfsr.BERLEKAMP_MASSEY_CALCULATE_SIG, FieldClass._POLY_ROOTS_CALCULATE_SIG, FieldClass._POLY_EVALUATE_CALCULATE_SIG, FieldClass._CONVOLVE_CALCULATE_SIG, int64, int64, int64))
DECODE_CALCULATE_SIG = numba.types.FunctionType(int64[:,:](int64[:,:], int64[:,:], int64, int64, int64, FieldClass._BINARY_CALCULATE_SIG, FieldClass._BINARY_CALCULATE_SIG, FieldClass._BINARY_CALCULATE_SIG, FieldClass._UNARY_CALCULATE_SIG, FieldClass._BINARY_CALCULATE_SIG, _lfsr.BERLEKAMP_MASSEY_CALCULATE_SIG, FieldClass._POLY_ROOTS_CALCULATE_SIG, FieldClass._POLY_EVALUATE_CALCULATE_SIG, FieldClass._CONVOLVE_CALCULATE_SIG, int64, int64, int64))

def decode_calculate(codeword, syndrome, t, primitive_element, ADD, SUBTRACT, MULTIPLY, RECIPROCAL, POWER, BERLEKAMP_MASSEY, POLY_ROOTS, POLY_EVAL, CONVOLVE, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY): # pragma: no cover
def decode_calculate(codeword, syndrome, c, t, primitive_element, ADD, SUBTRACT, MULTIPLY, RECIPROCAL, POWER, BERLEKAMP_MASSEY, POLY_ROOTS, POLY_EVAL, CONVOLVE, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY): # pragma: no cover
"""
References
----------
Expand Down Expand Up @@ -671,7 +671,7 @@ def decode_calculate(codeword, syndrome, t, primitive_element, ADD, SUBTRACT, MU
continue

# Compute σ'(x)
sigma_prime = np.zeros(v, dtype=np.int64)
sigma_prime = np.zeros(v, dtype=dtype)
for j in range(v):
degree = v - j
sigma_prime[j] = MULTIPLY(degree % CHARACTERISTIC, sigma[j], *args) # Scalar multiplication
Expand All @@ -680,11 +680,14 @@ def decode_calculate(codeword, syndrome, t, primitive_element, ADD, SUBTRACT, MU
# with degree v-1
Z0 = CONVOLVE(sigma[-v:], syndrome[i,0:v][::-1], ADD, MULTIPLY, *args)[-v:]

# The error value δi = -Z0(βi^-1) / σ'(βi^-1)
# The error value δi = -1 * βi^(1-c) * Z0(βi^-1) / σ'(βi^-1)
for j in range(v):
beta_i = POWER(beta_inv[j], c - 1, *args)
Z0_i = POLY_EVAL(Z0, np.array([beta_inv[j]], dtype=dtype), ADD, MULTIPLY, *args)[0] # NOTE: poly_eval() expects a 1-D array of values
sigma_prime_i = POLY_EVAL(sigma_prime, np.array([beta_inv[j]], dtype=dtype), ADD, MULTIPLY, *args)[0] # NOTE: poly_eval() expects a 1-D array of values
delta_i = MULTIPLY(SUBTRACT(0, Z0_i, *args), RECIPROCAL(sigma_prime_i, *args), *args)
delta_i = MULTIPLY(beta_i, Z0_i, *args)
delta_i = MULTIPLY(delta_i, RECIPROCAL(sigma_prime_i, *args), *args)
delta_i = SUBTRACT(0, delta_i, *args)
dec_codeword[i, n - 1 - error_locations[j]] = SUBTRACT(dec_codeword[i, n - 1 - error_locations[j]], delta_i, *args)

dec_codeword[i,-1] = v # The number of corrected errors
Expand Down

0 comments on commit 69bae7e

Please sign in to comment.