diff --git a/src/core/ldpc.cpp b/src/core/ldpc.cpp index e359822..234e9f3 100644 --- a/src/core/ldpc.cpp +++ b/src/core/ldpc.cpp @@ -5,7 +5,7 @@ namespace ldpc { ldpc_code::ldpc_code(const std::string &pcFileName) - : mMaxDC(0), + : mMaxDegree(0), mH(), mG() { @@ -43,7 +43,8 @@ namespace ldpc std::string line; int skipLines = 0; - if (!infile.good()) throw std::runtime_error("can not open file for reading"); + if (!infile.good()) + throw std::runtime_error("can not open file for reading"); while (getline(infile, line)) { @@ -53,15 +54,17 @@ namespace ldpc { int index; auto token = line.substr(0, i); - std::istringstream record(line.substr(i+1)); + std::istringstream record(line.substr(i + 1)); if (token.find("puncture") != std::string::npos) { - while (record >> index) mPuncture.push_back(index); - } + while (record >> index) + mPuncture.push_back(index); + } else if (token.find("shorten") != std::string::npos) { - while (record >> index) mShorten.push_back(index); + while (record >> index) + mShorten.push_back(index); } ++skipLines; @@ -76,17 +79,22 @@ namespace ldpc mH.read_from_file(pcFileName, skipLines); - // maximum check node degree - auto tmp = std::max_element(mH.row_neighbor().begin(), mH.row_neighbor().end(), [](const auto &a, const auto &b) { return (a.size() < b.size()); }); - mMaxDC = tmp->size(); + // maximum node degree + auto cd = std::max_element(mH.row_neighbor().begin(), mH.row_neighbor().end(), + [](const auto &a, const auto &b) { return (a.size() < b.size()); }); + auto vd = std::max_element(mH.col_neighbor().begin(), mH.col_neighbor().end(), + [](const auto &a, const auto &b) { return (a.size() < b.size()); }); + mMaxDegree = std::max(cd->size(), vd->size()); // position of transmitted bits for (int i = 0; i < nc(); i++) { auto tmp = std::find(mShorten.cbegin(), mShorten.cend(), i); - if (tmp != mShorten.cend()) continue; // skip if current index shortened + if (tmp != mShorten.cend()) + continue; // skip if current index shortened tmp = std::find(mPuncture.cbegin(), mPuncture.cend(), i); - if (tmp != mPuncture.cend()) continue; // skip if current index punctured + if (tmp != mPuncture.cend()) + continue; // skip if current index punctured mBitPos.push_back(i); } @@ -111,7 +119,7 @@ namespace ldpc os << "K : " << code.kc() << "\n"; os << "NNZ : " << code.nnz() << "\n"; //os << "Rank: " << code.mRank << "\n"; - //os << "max dc : " << code.max_dc() << "\n"; + //os << "max dc : " << code.max_degree() << "\n"; os << "puncture[" << code.puncture().size() << "] : " << code.puncture() << "\n"; os << "shorten[" << code.shorten().size() << "] : " << code.shorten() << "\n"; os << "Rate : " << rate << "\n"; diff --git a/src/core/ldpc.h b/src/core/ldpc.h index e943395..d3e205d 100644 --- a/src/core/ldpc.h +++ b/src/core/ldpc.h @@ -61,8 +61,8 @@ namespace ldpc const vec_int &puncture() const { return mPuncture; }; // Array of shorten indices const vec_int &shorten() const { return mShorten; }; - // Maximum check node degree - int max_dc() const { return mMaxDC; }; + // Maximum node degree + int max_degree() const { return mMaxDegree; }; // Index position of transmitted bits const vec_int &bit_pos() const { return mBitPos; } // Parity-check matrix @@ -73,7 +73,7 @@ namespace ldpc private: vec_int mPuncture; /* array pf punctured bit indices */ vec_int mShorten; /* array of shortened bit indices */ - int mMaxDC; + int mMaxDegree; // position of transmitted bits, i.e. puncture/shorten exluded vec_int mBitPos; diff --git a/src/decoding/decoder.cpp b/src/decoding/decoder.cpp index 8d2563a..828943f 100644 --- a/src/decoding/decoder.cpp +++ b/src/decoding/decoder.cpp @@ -83,8 +83,111 @@ namespace ldpc { } - int ldpc_decoder_bec::decode() + int ldpc_decoder_bec::decode() { + return 0; + } + + int ldpc_decoder_bec::decode(const vec_bits_t &channelInput) + { + auto &edges = mLdpcCode->H().nz_entry(); + + //initialize + for (int i = 0; i < mLdpcCode->nnz(); ++i) + { + mLv2c[i] = mLLRIn[edges[i].colIndex]; + } + + u32 I = 0; + while (I < mDecoderParam.iterations) + { + // CN update + for (int i = 0; i < mLdpcCode->mc(); ++i) + { + auto cw = mLdpcCode->H().row_neighbor()[i].size(); + auto &cn = mLdpcCode->H().row_neighbor()[i]; + mExMsgF[0] = mLv2c[cn[0].edgeIndex]; + mExMsgB[cw - 1] = mLv2c[cn[cw - 1].edgeIndex]; + for (u64 j = 1; j < cw; ++j) + { + mExMsgF[j] = cn_update(mExMsgF[j - 1], mLv2c[cn[j].edgeIndex]); + mExMsgB[cw - 1 - j] = cn_update(mExMsgB[cw - j], mLv2c[cn[cw - j - 1].edgeIndex]); + } + + mLc2v[cn[0].edgeIndex] = mExMsgB[1]; + mLc2v[cn[cw - 1].edgeIndex] = mExMsgF[cw - 2]; + for (u64 j = 1; j < cw - 1; ++j) + { + mLc2v[cn[j].edgeIndex] = cn_update(mExMsgF[j - 1], mExMsgB[j + 1]); + } + } + + // VN update + for (int i = 0; i < mLdpcCode->nc(); ++i) + { + // id channel output is no erasure + // propagate output + if (mLLRIn[i] != ERASURE) + { + auto &vn = mLdpcCode->H().col_neighbor()[i]; //neighbours of VN + for (const auto &hi : vn) + { + mLv2c[hi.edgeIndex] = channelInput[i].value; + } + + mLLROut[i] = channelInput[i].value; + mCO[i] = channelInput[i]; + } + else // channel output is erasure + { + auto vw = mLdpcCode->H().col_neighbor()[i].size(); + auto &vn = mLdpcCode->H().col_neighbor()[i]; + + mExMsgF[0] = mLc2v[vn[0].edgeIndex]; + mExMsgB[vw - 1] = mLc2v[vn[vw - 1].edgeIndex]; + for (u64 j = 1; j < vw; ++j) + { + mExMsgF[j] = vn_update(mExMsgF[j - 1], mLc2v[vn[j].edgeIndex], channelInput[i]); + mExMsgB[vw - 1 - j] = vn_update(mExMsgB[vw - j], mLc2v[vn[vw - j - 1].edgeIndex], channelInput[i]); + } + + mLv2c[vn[0].edgeIndex] = mExMsgB[1]; + mLv2c[vn[vw - 1].edgeIndex] = mExMsgF[vw - 2]; + for (u64 j = 1; j < vw - 1; ++j) + { + mLv2c[vn[j].edgeIndex] = vn_update(mExMsgF[j - 1], mExMsgB[j + 1], channelInput[i]); + } + + // final decision + mLLROut[i] = mExMsgF[vw - 1]; //mExMsgB[0] + // if all incoming messages are erasures set the wrong bit + mCO[i] = (mLLROut[i] == ERASURE) ? -channelInput[i] : channelInput[i]; + } + } + + if (mDecoderParam.earlyTerm) + { + // stop decoding when no erasures are left + bool erasure_found = false; + for (auto llr : mLLROut) + { + if (llr == ERASURE) + { + erasure_found = true; + break; + } + } + + if (!erasure_found) + { + break; + } + } + + ++I; + } + + return I; } } // namespace ldpc diff --git a/src/decoding/decoder.h b/src/decoding/decoder.h index 7ead4b0..ef88ca6 100644 --- a/src/decoding/decoder.h +++ b/src/decoding/decoder.h @@ -34,14 +34,14 @@ namespace ldpc mCNApprox(ldpc::jacobian), mCO(code->nc()), mLv2c(code->nnz()), mLc2v(code->nnz()), - mExMsgF(code->max_dc()), mExMsgB(code->max_dc()), + mExMsgF(code->max_degree()), mExMsgB(code->max_degree()), mLLRIn(code->nc()), mLLROut(code->nc()) { set_param(decoderParam); } virtual ~ldpc_decoder_base() = default; - virtual int decode(); + virtual int decode() { return 0; } // Verifies whether mCO is a codeword or not bool is_codeword() @@ -138,5 +138,20 @@ namespace ldpc virtual ~ldpc_decoder_bec() = default; int decode() override; + int decode(const vec_bits_t& channelInput); + + // BEC Decoder VN update + // if neither of the values equals the channel input then the output is an erasure + static inline constexpr u8 vn_update(const u8 l, const u8 r, const bits_t xi) + { + return ((xi.value == l) || (xi.value == r)) ? xi.value : ERASURE; + } + + // BEC Decoder CN update + // if any value is an erasure then the output is an erasure + static inline constexpr u8 cn_update(const u8 l, const u8 r) + { + return ((l == ERASURE) || (r == ERASURE)) ? ERASURE : (bits_t(l) + bits_t(r)).value; + } }; } // namespace ldpc diff --git a/src/sim/channel.cpp b/src/sim/channel.cpp index 2685d59..d033ef9 100644 --- a/src/sim/channel.cpp +++ b/src/sim/channel.cpp @@ -18,6 +18,8 @@ namespace ldpc void channel::encode_and_map() {} void channel::simulate() {} void channel::calculate_llrs() {} + int channel::decode() { return 0; } + const vec_bits_t &channel::estimate() const { return mCodeWord; } channel_awgn::channel_awgn(const std::shared_ptr &code, const decoder_param &decoderParams, diff --git a/src/sim/channel.h b/src/sim/channel.h index 93dca6d..de50a47 100644 --- a/src/sim/channel.h +++ b/src/sim/channel.h @@ -225,7 +225,7 @@ namespace ldpc */ int decode() override { - return mLdpcDecoder->decode(); + return mLdpcDecoder->decode(mCodeWord); } /** diff --git a/src/sim/ldpcsim.cpp b/src/sim/ldpcsim.cpp index 84b0345..4cf38c2 100644 --- a/src/sim/ldpcsim.cpp +++ b/src/sim/ldpcsim.cpp @@ -55,6 +55,19 @@ namespace ldpc ) ); } + else if (mChannelParams.type == std::string("BEC")) + { + mChannel.push_back( + std::make_shared( + channel_bec( + mLdpcCode, + mDecoderParams, + mChannelParams.seed + i, + 0. + ) + ) + ); + } else { throw std::runtime_error("No channel selected."); @@ -63,7 +76,8 @@ namespace ldpc } catch (std::exception &e) { - std::cout << "Error: ldpc_sim::ldpc_sim() " << e.what() << "\n"; + std::cout << "Error: ldpc_sim::ldpc_sim() " << e.what() << std::endl; + exit(EXIT_FAILURE); } } @@ -100,7 +114,7 @@ namespace ldpc auto maxFrames = mSimulationParams.maxFrames; std::string xValType = "SNR"; - if (mChannelParams.type == std::string("BSC")) + if (mChannelParams.type == std::string("BSC") || mChannelParams.type == std::string("BEC")) { xValType = "EPS"; // reverse the epsilon values, since we should start at the worst