Skip to content

Commit

Permalink
[lang] Support f64 cpu sparse linear solver (taichi-dev#6657)
Browse files Browse the repository at this point in the history
  • Loading branch information
FantasyVR authored and quadpixels committed May 13, 2023
1 parent 7c6e4cc commit 9da4488
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 61 deletions.
148 changes: 118 additions & 30 deletions taichi/program/sparse_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,80 @@

#include <unordered_map>

#define MAKE_SOLVER(dt, type, order) \
{ \
{#dt, #type, #order}, []() -> std::unique_ptr<SparseSolver> { \
using T = Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower, \
Eigen::order##Ordering<int>>; \
return std::make_unique< \
EigenSparseSolver<T, Eigen::SparseMatrix<dt>>>(); \
} \
}
namespace taichi::lang {
#define EIGEN_LLT_SOLVER_INSTANTIATION(dt, type, order) \
template class EigenSparseSolver< \
Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower, \
Eigen::order##Ordering<int>>, \
Eigen::SparseMatrix<dt>>;
#define EIGEN_LU_SOLVER_INSTANTIATION(dt, type, order) \
template class EigenSparseSolver< \
Eigen::Sparse##type<Eigen::SparseMatrix<dt>, \
Eigen::order##Ordering<int>>, \
Eigen::SparseMatrix<dt>>;
// Explicit instantiation of EigenSparseSolver
EIGEN_LLT_SOLVER_INSTANTIATION(float32, LLT, AMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float32, LLT, COLAMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float32, LDLT, AMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float32, LDLT, COLAMD);
EIGEN_LU_SOLVER_INSTANTIATION(float32, LU, AMD);
EIGEN_LU_SOLVER_INSTANTIATION(float32, LU, COLAMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float64, LLT, AMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float64, LLT, COLAMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float64, LDLT, AMD);
EIGEN_LLT_SOLVER_INSTANTIATION(float64, LDLT, COLAMD);
EIGEN_LU_SOLVER_INSTANTIATION(float64, LU, AMD);
EIGEN_LU_SOLVER_INSTANTIATION(float64, LU, COLAMD);
} // namespace taichi::lang

#define INSTANTIATE_SOLVER(dt, type, order) \
using dt##type##order = \
// Explicit instantiation of the template class EigenSparseSolver::solve
#define EIGEN_LLT_SOLVE_INSTANTIATION(dt, type, order, df) \
using T##dt = Eigen::VectorX##df; \
using S##dt##type##order = \
Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower, \
Eigen::order##Ordering<int>>; \
template void \
EigenSparseSolver<dt##type##order, Eigen::SparseMatrix<dt>>::solve_rf( \
Program *prog, const SparseMatrix &sm, const Ndarray &b, Ndarray &x);
template T##dt \
EigenSparseSolver<S##dt##type##order, Eigen::SparseMatrix<dt>>::solve( \
const T##dt &b);
#define EIGEN_LU_SOLVE_INSTANTIATION(dt, type, order, df) \
using LUT##dt = Eigen::VectorX##df; \
using LUS##dt##type##order = \
Eigen::Sparse##type<Eigen::SparseMatrix<dt>, \
Eigen::order##Ordering<int>>; \
template LUT##dt \
EigenSparseSolver<LUS##dt##type##order, Eigen::SparseMatrix<dt>>::solve( \
const LUT##dt &b);

// Explicit instantiation of the template class EigenSparseSolver::solve_rf
#define INSTANTIATE_LLT_SOLVE_RF(dt, type, order, df) \
using llt##dt##type##order = \
Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower, \
Eigen::order##Ordering<int>>; \
template void EigenSparseSolver<llt##dt##type##order, \
Eigen::SparseMatrix<dt>>::solve_rf<df, \
dt>( \
Program * prog, const SparseMatrix &sm, const Ndarray &b, \
const Ndarray &x);

#define INSTANTIATE_LU_SOLVE_RF(dt, type, order, df) \
using lu##dt##type##order = \
Eigen::Sparse##type<Eigen::SparseMatrix<dt>, \
Eigen::order##Ordering<int>>; \
template void EigenSparseSolver<lu##dt##type##order, \
Eigen::SparseMatrix<dt>>::solve_rf<df, \
dt>( \
Program * prog, const SparseMatrix &sm, const Ndarray &b, \
const Ndarray &x);

#define MAKE_EIGEN_SOLVER(dt, type, order) \
std::make_unique<EigenSparseSolver##dt##type##order>()

#define MAKE_SOLVER(dt, type, order) \
{ \
{#dt, #type, #order}, []() -> std::unique_ptr<SparseSolver> { \
return MAKE_EIGEN_SOLVER(dt, type, order); \
} \
}

using Triplets = std::tuple<std::string, std::string, std::string>;
namespace {
Expand Down Expand Up @@ -70,32 +127,53 @@ void EigenSparseSolver<EigenSolver, EigenMatrix>::factorize(
}

template <class EigenSolver, class EigenMatrix>
Eigen::VectorXf EigenSparseSolver<EigenSolver, EigenMatrix>::solve(
const Eigen::Ref<const Eigen::VectorXf> &b) {
template <typename T>
T EigenSparseSolver<EigenSolver, EigenMatrix>::solve(const T &b) {
return solver_.solve(b);
}

EIGEN_LLT_SOLVE_INSTANTIATION(float32, LLT, AMD, f);
EIGEN_LLT_SOLVE_INSTANTIATION(float32, LLT, COLAMD, f);
EIGEN_LLT_SOLVE_INSTANTIATION(float32, LDLT, AMD, f);
EIGEN_LLT_SOLVE_INSTANTIATION(float32, LDLT, COLAMD, f);
EIGEN_LU_SOLVE_INSTANTIATION(float32, LU, AMD, f);
EIGEN_LU_SOLVE_INSTANTIATION(float32, LU, COLAMD, f);
EIGEN_LLT_SOLVE_INSTANTIATION(float64, LLT, AMD, d);
EIGEN_LLT_SOLVE_INSTANTIATION(float64, LLT, COLAMD, d);
EIGEN_LLT_SOLVE_INSTANTIATION(float64, LDLT, AMD, d);
EIGEN_LLT_SOLVE_INSTANTIATION(float64, LDLT, COLAMD, d);
EIGEN_LU_SOLVE_INSTANTIATION(float64, LU, AMD, d);
EIGEN_LU_SOLVE_INSTANTIATION(float64, LU, COLAMD, d);

template <class EigenSolver, class EigenMatrix>
bool EigenSparseSolver<EigenSolver, EigenMatrix>::info() {
return solver_.info() == Eigen::Success;
}

template <class EigenSolver, class EigenMatrix>
template <typename T, typename V>
void EigenSparseSolver<EigenSolver, EigenMatrix>::solve_rf(
Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) {
const Ndarray &x) {
size_t db = prog->get_ndarray_data_ptr_as_int(&b);
size_t dX = prog->get_ndarray_data_ptr_as_int(&x);
Eigen::Map<Eigen::VectorXf>((float *)dX, rows_) =
solver_.solve(Eigen::Map<Eigen::VectorXf>((float *)db, cols_));
Eigen::Map<T>((V *)dX, rows_) = solver_.solve(Eigen::Map<T>((V *)db, cols_));
}

INSTANTIATE_SOLVER(float32, LLT, COLAMD)
INSTANTIATE_SOLVER(float32, LDLT, COLAMD)
INSTANTIATE_SOLVER(float32, LLT, AMD)
INSTANTIATE_SOLVER(float32, LDLT, AMD)
INSTANTIATE_LLT_SOLVE_RF(float32, LLT, COLAMD, Eigen::VectorXf)
INSTANTIATE_LLT_SOLVE_RF(float32, LDLT, COLAMD, Eigen::VectorXf)
INSTANTIATE_LLT_SOLVE_RF(float32, LLT, AMD, Eigen::VectorXf)
INSTANTIATE_LLT_SOLVE_RF(float32, LDLT, AMD, Eigen::VectorXf)
INSTANTIATE_LU_SOLVE_RF(float32, LU, AMD, Eigen::VectorXf)
INSTANTIATE_LU_SOLVE_RF(float32, LU, COLAMD, Eigen::VectorXf)
INSTANTIATE_LLT_SOLVE_RF(float64, LLT, COLAMD, Eigen::VectorXd)
INSTANTIATE_LLT_SOLVE_RF(float64, LDLT, COLAMD, Eigen::VectorXd)
INSTANTIATE_LLT_SOLVE_RF(float64, LLT, AMD, Eigen::VectorXd)
INSTANTIATE_LLT_SOLVE_RF(float64, LDLT, AMD, Eigen::VectorXd)
INSTANTIATE_LU_SOLVE_RF(float64, LU, AMD, Eigen::VectorXd)
INSTANTIATE_LU_SOLVE_RF(float64, LU, COLAMD, Eigen::VectorXd)

CuSparseSolver::CuSparseSolver() {
#if defined(TI_WITH_CUDA)
Expand Down Expand Up @@ -184,7 +262,7 @@ void CuSparseSolver::factorize(const SparseMatrix &sm) {
void CuSparseSolver::solve_cu(Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) {
const Ndarray &x) {
#ifdef TI_WITH_CUDA
cusparseHandle_t cusparseHandle = nullptr;
CUSPARSEDriver::get_instance().cpCreate(&cusparseHandle);
Expand Down Expand Up @@ -349,7 +427,7 @@ void CuSparseSolver::solve_cu(Program *prog,
void CuSparseSolver::solve_rf(Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) {
const Ndarray &x) {
#if defined(TI_WITH_CUDA)
if (is_analyzed_ == false) {
analyze_pattern(sm);
Expand Down Expand Up @@ -383,8 +461,10 @@ std::unique_ptr<SparseSolver> make_sparse_solver(DataType dt,
using func_type = std::unique_ptr<SparseSolver> (*)();
static const std::unordered_map<key_type, func_type, key_hash>
solver_factory = {
MAKE_SOLVER(float32, LLT, AMD), MAKE_SOLVER(float32, LLT, COLAMD),
MAKE_SOLVER(float32, LDLT, AMD), MAKE_SOLVER(float32, LDLT, COLAMD)};
MAKE_SOLVER(float32, LLT, AMD), MAKE_SOLVER(float32, LLT, COLAMD),
MAKE_SOLVER(float32, LDLT, AMD), MAKE_SOLVER(float32, LDLT, COLAMD),
MAKE_SOLVER(float64, LLT, AMD), MAKE_SOLVER(float64, LLT, COLAMD),
MAKE_SOLVER(float64, LDLT, AMD), MAKE_SOLVER(float64, LDLT, COLAMD)};
static const std::unordered_map<std::string, std::string> dt_map = {
{"f32", "float32"}, {"f64", "float64"}};
auto it = dt_map.find(taichi::lang::data_type_name(dt));
Expand All @@ -397,9 +477,17 @@ std::unique_ptr<SparseSolver> make_sparse_solver(DataType dt,
auto solver_func = solver_factory.at(solver_key);
return solver_func();
} else if (solver_type == "LU") {
using EigenMatrix = Eigen::SparseMatrix<float32>;
using LU = Eigen::SparseLU<EigenMatrix>;
return std::make_unique<EigenSparseSolver<LU, EigenMatrix>>();
if (it->first == "f32") {
using EigenMatrix = Eigen::SparseMatrix<float32>;
using LU = Eigen::SparseLU<EigenMatrix>;
return std::make_unique<EigenSparseSolver<LU, EigenMatrix>>();
} else if (it->first == "f64") {
using EigenMatrix = Eigen::SparseMatrix<float64>;
using LU = Eigen::SparseLU<EigenMatrix>;
return std::make_unique<EigenSparseSolver<LU, EigenMatrix>>();
} else {
TI_ERROR("Not supported sparse solver data type: {}", it->second);
}
} else
TI_ERROR("Not supported sparse solver type: {}", solver_type);
}
Expand Down
58 changes: 34 additions & 24 deletions taichi/program/sparse_solver.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
#pragma once

#include "sparse_matrix.h"

#include "taichi/ir/type.h"
#include "taichi/rhi/cuda/cuda_driver.h"
#include "taichi/program/program.h"

#include "sparse_matrix.h"
#define DECLARE_EIGEN_LLT_SOLVER(dt, type, order) \
typedef EigenSparseSolver< \
Eigen::Simplicial##type<Eigen::SparseMatrix<dt>, Eigen::Lower, \
Eigen::order##Ordering<int>>, \
Eigen::SparseMatrix<dt>> \
EigenSparseSolver##dt##type##order;

#define DECLARE_EIGEN_LU_SOLVER(dt, type, order) \
typedef EigenSparseSolver<Eigen::Sparse##type<Eigen::SparseMatrix<dt>, \
Eigen::order##Ordering<int>>, \
Eigen::SparseMatrix<dt>> \
EigenSparseSolver##dt##type##order;

namespace taichi::lang {

Expand All @@ -25,15 +38,6 @@ class SparseSolver {
virtual bool compute(const SparseMatrix &sm) = 0;
virtual void analyze_pattern(const SparseMatrix &sm) = 0;
virtual void factorize(const SparseMatrix &sm) = 0;
virtual Eigen::VectorXf solve(const Eigen::Ref<const Eigen::VectorXf> &b) = 0;
virtual void solve_rf(Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) = 0;
virtual void solve_cu(Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) = 0;
virtual bool info() = 0;
};

Expand All @@ -47,21 +51,30 @@ class EigenSparseSolver : public SparseSolver {
bool compute(const SparseMatrix &sm) override;
void analyze_pattern(const SparseMatrix &sm) override;
void factorize(const SparseMatrix &sm) override;
Eigen::VectorXf solve(const Eigen::Ref<const Eigen::VectorXf> &b) override;
void solve_cu(Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) override {
TI_NOT_IMPLEMENTED;
};
template <typename T>
T solve(const T &b);

template <typename T, typename V>
void solve_rf(Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) override;

const Ndarray &x);
bool info() override;
};

DECLARE_EIGEN_LLT_SOLVER(float32, LLT, AMD);
DECLARE_EIGEN_LLT_SOLVER(float32, LLT, COLAMD);
DECLARE_EIGEN_LLT_SOLVER(float32, LDLT, AMD);
DECLARE_EIGEN_LLT_SOLVER(float32, LDLT, COLAMD);
DECLARE_EIGEN_LU_SOLVER(float32, LU, AMD);
DECLARE_EIGEN_LU_SOLVER(float32, LU, COLAMD);
DECLARE_EIGEN_LLT_SOLVER(float64, LLT, AMD);
DECLARE_EIGEN_LLT_SOLVER(float64, LLT, COLAMD);
DECLARE_EIGEN_LLT_SOLVER(float64, LDLT, AMD);
DECLARE_EIGEN_LLT_SOLVER(float64, LDLT, COLAMD);
DECLARE_EIGEN_LU_SOLVER(float64, LU, AMD);
DECLARE_EIGEN_LU_SOLVER(float64, LU, COLAMD);

class CuSparseSolver : public SparseSolver {
private:
csrcholInfo_t info_{nullptr};
Expand All @@ -81,17 +94,14 @@ class CuSparseSolver : public SparseSolver {
void analyze_pattern(const SparseMatrix &sm) override;

void factorize(const SparseMatrix &sm) override;
Eigen::VectorXf solve(const Eigen::Ref<const Eigen::VectorXf> &b) override {
TI_NOT_IMPLEMENTED;
};
void solve_cu(Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) override;
const Ndarray &x);
void solve_rf(Program *prog,
const SparseMatrix &sm,
const Ndarray &b,
Ndarray &x) override;
const Ndarray &x);
bool info() override {
TI_NOT_IMPLEMENTED;
};
Expand Down
38 changes: 35 additions & 3 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1259,11 +1259,43 @@ void export_lang(py::module &m) {
.def("compute", &SparseSolver::compute)
.def("analyze_pattern", &SparseSolver::analyze_pattern)
.def("factorize", &SparseSolver::factorize)
.def("solve", &SparseSolver::solve)
.def("solve_cu", &SparseSolver::solve_cu)
.def("solve_rf", &SparseSolver::solve_rf)
.def("info", &SparseSolver::info);

#define REGISTER_EIGEN_SOLVER(dt, type, order, fd) \
py::class_<EigenSparseSolver##dt##type##order, SparseSolver>( \
m, "EigenSparseSolver" #dt #type #order) \
.def("compute", &EigenSparseSolver##dt##type##order::compute) \
.def("analyze_pattern", \
&EigenSparseSolver##dt##type##order::analyze_pattern) \
.def("factorize", &EigenSparseSolver##dt##type##order::factorize) \
.def("solve", \
&EigenSparseSolver##dt##type##order::solve<Eigen::VectorX##fd>) \
.def("solve_rf", \
&EigenSparseSolver##dt##type##order::solve_rf<Eigen::VectorX##fd, \
dt>) \
.def("info", &EigenSparseSolver##dt##type##order::info);

REGISTER_EIGEN_SOLVER(float32, LLT, AMD, f)
REGISTER_EIGEN_SOLVER(float32, LLT, COLAMD, f)
REGISTER_EIGEN_SOLVER(float32, LDLT, AMD, f)
REGISTER_EIGEN_SOLVER(float32, LDLT, COLAMD, f)
REGISTER_EIGEN_SOLVER(float32, LU, AMD, f)
REGISTER_EIGEN_SOLVER(float32, LU, COLAMD, f)
REGISTER_EIGEN_SOLVER(float64, LLT, AMD, d)
REGISTER_EIGEN_SOLVER(float64, LLT, COLAMD, d)
REGISTER_EIGEN_SOLVER(float64, LDLT, AMD, d)
REGISTER_EIGEN_SOLVER(float64, LDLT, COLAMD, d)
REGISTER_EIGEN_SOLVER(float64, LU, AMD, d)
REGISTER_EIGEN_SOLVER(float64, LU, COLAMD, d)

py::class_<CuSparseSolver, SparseSolver>(m, "CuSparseSolver")
.def("compute", &CuSparseSolver::compute)
.def("analyze_pattern", &CuSparseSolver::analyze_pattern)
.def("factorize", &CuSparseSolver::factorize)
.def("solve_rf", &CuSparseSolver::solve_rf)
.def("solve_cu", &CuSparseSolver::solve_cu)
.def("info", &CuSparseSolver::info);

m.def("make_sparse_solver", &make_sparse_solver);
m.def("make_cusparse_solver", &make_cusparse_solver);

Expand Down
11 changes: 7 additions & 4 deletions tests/python/test_sparse_linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
from tests import test_utils


@pytest.mark.parametrize("dtype", [ti.f32])
@pytest.mark.parametrize("dtype", [ti.f32, ti.f64])
@pytest.mark.parametrize("solver_type", ["LLT", "LDLT", "LU"])
@pytest.mark.parametrize("ordering", ["AMD", "COLAMD"])
@test_utils.test(arch=ti.cpu)
@test_utils.test(arch=ti.x64)
def test_sparse_LLT_solver(dtype, solver_type, ordering):
n = 10
A = np.random.rand(n, n)
A_psd = np.dot(A, A.transpose())
Abuilder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=300)
b = ti.field(ti.f32, shape=n)
Abuilder = ti.linalg.SparseMatrixBuilder(n,
n,
max_num_triplets=100,
dtype=dtype)
b = ti.field(dtype=dtype, shape=n)

@ti.kernel
def fill(Abuilder: ti.types.sparse_matrix_builder(),
Expand Down

0 comments on commit 9da4488

Please sign in to comment.