Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly mask rows and columns when fetching balanced pixels as dense matrices #248

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 114 additions & 23 deletions src/libhictk/balancing/include/hictk/balancing/impl/weights_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,28 @@
#pragma once

#include <fmt/format.h>
#include <parallel_hashmap/phmap.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string_view>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "hictk/common.hpp"
#include "hictk/pixel.hpp"

namespace hictk::balancing {

inline Weights::Weights(std::vector<double> weights, Type type) noexcept
: _weights(std::move(weights)), _type(type) {
: _weights(std::make_shared<WeightVect>(std::move(weights))), _type(type) {
assert(_type != Type::INFER && _type != Type::UNKNOWN);
}

Expand All @@ -35,14 +39,56 @@ inline Weights::Weights(std::vector<double> weights, std::string_view name)
}
}

inline Weights::operator bool() const noexcept { return !_weights.empty(); }
inline Weights::Weights(double weight, std::size_t size, Type type) noexcept
: _weights(ConstWeight{weight, size}), _type(type) {
assert(_type != Type::INFER && _type != Type::UNKNOWN);
}

inline Weights::Weights(double weight, std::size_t size, std::string_view name)
: Weights(weight, size, Weights::infer_type(name)) {
assert(_type != Type::INFER);
if (_type == Type::UNKNOWN) {
throw std::runtime_error(
fmt::format(FMT_STRING("unable to infer type for \"{}\" weights"), name));
}
}
inline Weights::Weights(std::variant<ConstWeight, WeightVectPtr> weights, Type type_) noexcept
: _weights(std::move(weights)), _type(type_) {}

inline Weights::operator bool() const noexcept {
if (is_constant()) {
return std::get<ConstWeight>(_weights).size != 0;
}

const auto &weights = std::get<WeightVectPtr>(_weights);
return !!weights && !weights->empty();
}

inline double Weights::at(std::size_t i) const {
if (is_constant()) {
const auto &[w, size] = std::get<ConstWeight>(_weights);

if (i >= size) {
throw std::out_of_range("Weights::at()");
}

inline double Weights::operator[](std::size_t i) const noexcept {
assert(i < _weights.size());
return _weights[i];
return w;
}

return std::get<WeightVectPtr>(_weights)->at(i);
}

inline double Weights::at(std::size_t i) const { return _weights.at(i); }
inline double Weights::at(std::size_t i, Type type_) const {
if (HICTK_UNLIKELY(type_ != Type::MULTIPLICATIVE && type_ != Type::DIVISIVE)) {
throw std::logic_error("Type should be Type::MULTIPLICATIVE or Type::DIVISIVE");
}

if (type_ == type()) {
return at(i);
}

return 1.0 / at(i);
}

template <typename N>
inline ThinPixel<N> Weights::balance(ThinPixel<N> p) const {
Expand All @@ -59,8 +105,8 @@ inline Pixel<N> Weights::balance(Pixel<N> p) const {
template <typename N1, typename N2>
inline N1 Weights::balance(std::uint64_t bin1_id, std::uint64_t bin2_id, N2 count) const {
assert(std::is_floating_point_v<N1>);
const auto w1 = _weights[bin1_id];
const auto w2 = _weights[bin2_id];
const auto w1 = at(conditional_static_cast<std::size_t>(bin1_id));
const auto w2 = at(conditional_static_cast<std::size_t>(bin2_id));

auto count_ = conditional_static_cast<double>(count);

Expand All @@ -73,28 +119,43 @@ inline N1 Weights::balance(std::uint64_t bin1_id, std::uint64_t bin2_id, N2 coun
return conditional_static_cast<N1>(count_);
}

inline const std::vector<double> Weights::operator()(Type type_) const {
if (type_ != Type::MULTIPLICATIVE && type_ != Type::DIVISIVE) {
inline Weights Weights::operator()(Type type_) const {
if (HICTK_UNLIKELY(type_ != Type::MULTIPLICATIVE && type_ != Type::DIVISIVE)) {
throw std::logic_error("Type should be Type::MULTIPLICATIVE or Type::DIVISIVE");
}

if (type_ == type()) {
return _weights;
return *this;
}

if (is_constant()) {
const auto &[w, size] = std::get<ConstWeight>(_weights);
return {ConstWeight{1.0 / w, size}, type_};
}

auto weights = _weights;
std::transform(weights.begin(), weights.end(), weights.begin(),
const auto &buff = std::get<WeightVectPtr>(_weights);
assert(!!buff);

auto weights = std::make_shared<WeightVect>(buff->size());
std::transform(buff->begin(), buff->end(), weights->begin(),
[](const auto n) { return 1.0 / n; });

return weights;
return {std::move(weights), type_};
}

constexpr auto Weights::type() const noexcept -> Type { return _type; }

inline std::size_t Weights::size() const noexcept { return _weights.size(); }
inline std::size_t Weights::size() const noexcept {
if (is_constant()) {
return std::get<ConstWeight>(_weights).size;
}
const auto &weights = std::get<WeightVectPtr>(_weights);
assert(!!weights);
return weights->size();
}

inline auto Weights::infer_type(std::string_view name) -> Type {
const static phmap::flat_hash_map<std::string_view, Type> mappings{
constexpr std::array<std::pair<std::string_view, Type>, 14> mappings{
{{"VC", Type::DIVISIVE},
{"INTER_VC", Type::DIVISIVE},
{"GW_VC", Type::DIVISIVE},
Expand All @@ -110,26 +171,56 @@ inline auto Weights::infer_type(std::string_view name) -> Type {
{"GW_ICE", Type::MULTIPLICATIVE},
{"weight", Type::MULTIPLICATIVE}}};

auto it = mappings.find(name);
auto it = std::find_if(mappings.begin(), mappings.end(),
[&](const auto &p) { return p.first == name; });
if (it == mappings.end()) {
return Weights::Type::UNKNOWN;
}
return it->second;
}

inline void Weights::rescale(double scaling_factor) noexcept {
std::transform(_weights.begin(), _weights.end(), _weights.begin(),
[&](auto w) { return w * std::sqrt(scaling_factor); });
if (is_constant()) {
auto &w = std::get<ConstWeight>(_weights).w;
w *= std::sqrt(scaling_factor);
} else {
auto &weights = std::get<WeightVectPtr>(_weights);
assert(!!weights);
std::transform(weights->begin(), weights->end(), weights->begin(),
[&](auto w) { return w * std::sqrt(scaling_factor); });
}
}

inline void Weights::rescale(const std::vector<double> &scaling_factors,
const std::vector<std::uint64_t> &offsets) noexcept {
const std::vector<std::uint64_t> &offsets) {
if (is_constant()) {
if (scaling_factors.size() == 1) {
rescale(scaling_factors.front());
return;
}
throw std::runtime_error(
"rescaling ConstWeight with multiple scaling factors is not supported");
}

const auto &weights = std::get<WeightVectPtr>(_weights);
assert(!!weights);

for (std::size_t i = 0; i < scaling_factors.size(); ++i) {
auto first = _weights.begin() + std::ptrdiff_t(offsets[i]);
auto last = _weights.begin() + std::ptrdiff_t(offsets[i + 1]);
auto first = weights->begin() + std::ptrdiff_t(offsets[i]);
auto last = weights->begin() + std::ptrdiff_t(offsets[i + 1]);
std::transform(first, last, first,
[s = scaling_factors[i]](const double w) { return w * std::sqrt(s); });
}
}

inline bool Weights::is_constant() const noexcept {
return std::holds_alternative<ConstWeight>(_weights);
}

inline Weights::iterator::iterator(std::vector<double>::const_iterator it)
: _it(std::move(it)), _i(0) {}

inline Weights::iterator::iterator(const hictk::balancing::Weights::ConstWeight &weight)
: _it(&weight), _i(0) {}

} // namespace hictk::balancing
33 changes: 28 additions & 5 deletions src/libhictk/balancing/include/hictk/balancing/weights.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <memory>
#include <string>
#include <string_view>
#include <variant>
#include <vector>

#include "hictk/pixel.hpp"
Expand All @@ -21,39 +22,61 @@ class Weights {
public:
enum class Type { INFER, DIVISIVE, MULTIPLICATIVE, UNKNOWN };

class iterator;

private:
std::vector<double> _weights{};
struct ConstWeight {
double w{};
std::size_t size{};
};
using WeightVect = std::vector<double>;
using WeightVectPtr = std::shared_ptr<WeightVect>;

std::variant<ConstWeight, WeightVectPtr> _weights{ConstWeight{}};
Type _type{};

public:
Weights() = default;
Weights(std::vector<double> weights, Type type) noexcept;
Weights(std::vector<double> weights, std::string_view name);
Weights(double weight, std::size_t size, Type type) noexcept;
Weights(double weight, std::size_t size, std::string_view name);

[[nodiscard]] explicit operator bool() const noexcept;
[[nodiscard]] double operator[](std::size_t i) const noexcept;

[[nodiscard]] double at(std::size_t i) const;
[[nodiscard]] double at(std::size_t i, Type type_) const;

template <typename N>
[[nodiscard]] hictk::ThinPixel<N> balance(hictk::ThinPixel<N> p) const;
template <typename N>
[[nodiscard]] hictk::Pixel<N> balance(hictk::Pixel<N> p) const;

[[nodiscard]] const std::vector<double> operator()(Type type_) const;
// [[nodiscard]] const std::vector<double>& operator()() const noexcept;
[[nodiscard]] Weights operator()(Type type_) const;
[[nodiscard]] constexpr auto type() const noexcept -> Type;
[[nodiscard]] std::size_t size() const noexcept;

[[nodiscard]] static auto infer_type(std::string_view name) -> Type;

void rescale(double scaling_factor) noexcept;
void rescale(const std::vector<double>& scaling_factors,
const std::vector<std::uint64_t>& offsets) noexcept;
const std::vector<std::uint64_t>& offsets);

class iterator {
std::variant<std::vector<double>::const_iterator, const ConstWeight*> _it{nullptr};
std::ptrdiff_t _i{std::numeric_limits<std::ptrdiff_t>::max()};

explicit iterator(std::vector<double>::const_iterator it);
explicit iterator(const ConstWeight& weight);

public:
};

private:
Weights(std::variant<ConstWeight, WeightVectPtr> weights, Type type_) noexcept;
template <typename N1, typename N2>
[[nodiscard]] N1 balance(std::uint64_t bin1_id, std::uint64_t bin2_id, N2 count) const;
[[nodiscard]] bool is_constant() const noexcept;
};

using WeightMap = phmap::flat_hash_map<std::string, std::shared_ptr<const Weights>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ inline PixelSelector PixelSelector::fetch(PixelCoordinates coord1, PixelCoordina
_weights};
}

inline std::shared_ptr<const balancing::Weights> PixelSelector::weights() const noexcept {
return _weights;
}

template <typename N>
inline PixelSelector::iterator<N>::iterator(
const Dataset &pixels_bin1_id, const Dataset &pixels_bin2_id, const Dataset &pixels_count,
Expand Down
2 changes: 2 additions & 0 deletions src/libhictk/cooler/include/hictk/cooler/pixel_selector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class PixelSelector {

[[nodiscard]] PixelSelector fetch(PixelCoordinates coord1, PixelCoordinates coord2) const;

[[nodiscard]] std::shared_ptr<const balancing::Weights> weights() const noexcept;

public:
template <typename N>
class iterator {
Expand Down
12 changes: 12 additions & 0 deletions src/libhictk/hic/include/hictk/hic/impl/hic_file_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ inline std::shared_ptr<const internal::HiCFooter> File::get_footer(
auto weights1 = _weight_cache->find_or_emplace(chrom1, norm);
auto weights2 = _weight_cache->find_or_emplace(chrom2, norm);

if (!(*weights1) && norm == balancing::Method::NONE()) {
const auto num_bins = (chrom1.size() + resolution - 1) / resolution;
weights1 = std::make_shared<balancing::Weights>(std::vector<double>(num_bins, 1.0),
balancing::Weights::Type::DIVISIVE);
}

if (!(*weights2) && norm == balancing::Method::NONE()) {
const auto num_bins = (chrom2.size() + resolution - 1) / resolution;
weights2 = std::make_shared<balancing::Weights>(std::vector<double>(num_bins, 1.0),
balancing::Weights::Type::DIVISIVE);
}

auto [node, _] = _footers.emplace(_fs->read_footer(chrom1.id(), chrom2.id(), matrix_type, norm,
unit, resolution, weights1, weights2));

Expand Down
10 changes: 5 additions & 5 deletions src/libhictk/hic/include/hictk/hic/impl/pixel_selector_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,18 +735,18 @@ inline const BinTable &PixelSelectorAll::bins() const noexcept {
}
inline std::shared_ptr<const BinTable> PixelSelectorAll::bins_ptr() const noexcept { return _bins; }

inline std::vector<double> PixelSelectorAll::weights() const {
std::vector<double> weights_{};
weights_.reserve(bins().size());
inline balancing::Weights PixelSelectorAll::weights() const {
std::vector<double> weights_(bins().size(), std::numeric_limits<double>::quiet_NaN());

std::for_each(_selectors.begin(), _selectors.end(), [&](const PixelSelector &sel) {
if (sel.is_intra()) {
const auto chrom_weights = sel.weights1()(balancing::Weights::Type::DIVISIVE);
weights_.insert(weights_.end(), chrom_weights.begin(), chrom_weights.end());
const auto offset = static_cast<std::ptrdiff_t>(bins().at(sel.chrom1()).id());
std::copy(chrom_weights.begin(), chrom_weights.end(), weights_.begin() + offset);
}
});

return weights_;
return {std::move(weights_), balancing::Weights::Type::DIVISIVE};
}

template <typename N>
Expand Down
2 changes: 1 addition & 1 deletion src/libhictk/hic/include/hictk/hic/pixel_selector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class PixelSelectorAll {
[[nodiscard]] std::uint32_t resolution() const noexcept;
[[nodiscard]] const BinTable &bins() const noexcept;
[[nodiscard]] std::shared_ptr<const BinTable> bins_ptr() const noexcept;
[[nodiscard]] std::vector<double> weights() const;
[[nodiscard]] balancing::Weights weights() const;

template <typename N>
class iterator {
Expand Down
2 changes: 1 addition & 1 deletion src/libhictk/transformers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ target_sources(
"${CMAKE_CURRENT_SOURCE_DIR}/include")
target_include_directories(transformers INTERFACE "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>"
"$<INSTALL_INTERFACE:include>")
target_link_libraries(transformers INTERFACE hictk::cooler hictk::hic)
target_link_libraries(transformers INTERFACE hictk::cooler hictk::file hictk::hic)

if(HICTK_WITH_ARROW_SHARED)
target_link_system_libraries(transformers INTERFACE "$<$<BOOL:${HICTK_WITH_ARROW}>:Arrow::arrow_shared>")
Expand Down
Loading