Skip to content

Commit

Permalink
[Numerics] Use Eigen instead of internal LAPACK for DenseMatrix opera…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
speth committed May 15, 2016
1 parent bf67cce commit b49c2e4
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 60 deletions.
5 changes: 5 additions & 0 deletions include/cantera/numerics/eigen_dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@
#else
#include "cantera/ext/Eigen/Dense"
#endif

namespace Cantera {
typedef Eigen::Map<Eigen::MatrixXd> MappedMatrix;
typedef Eigen::Map<Eigen::VectorXd> MappedVector;
}
164 changes: 104 additions & 60 deletions src/numerics/DenseMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

// Copyright 2001 California Institute of Technology

#include "cantera/numerics/ctlapack.h"
#include "cantera/numerics/DenseMatrix.h"
#include "cantera/base/stringUtils.h"
#if CT_USE_LAPACK
#include "cantera/numerics/ctlapack.h"
#else
#include "cantera/numerics/eigen_dense.h"
#endif

namespace Cantera
{
Expand Down Expand Up @@ -85,10 +89,17 @@ const doublereal* const* DenseMatrix::const_colPts() const

void DenseMatrix::mult(const double* b, double* prod) const
{
#if CT_USE_LAPACK
ct_dgemv(ctlapack::ColMajor, ctlapack::NoTranspose,
static_cast<int>(nRows()),
static_cast<int>(nColumns()), 1.0, ptrColumn(0),
static_cast<int>(nRows()), b, 1, 0.0, prod, 1);
#else
MappedMatrix mat(const_cast<double*>(m_data.data()), nRows(), nColumns());
MappedVector bm(const_cast<double*>(b), nColumns());
MappedVector pm(prod, nRows());
pm = mat * bm;
#endif
}

void DenseMatrix::mult(const DenseMatrix& B, DenseMatrix& prod) const
Expand Down Expand Up @@ -135,44 +146,63 @@ int solve(DenseMatrix& A, double* b, size_t nrhs, size_t ldb)
}
throw CanteraError("solve(DenseMatrix& A, double* b)", "Can only solve a square matrix");
}
int info = 0;
ct_dgetrf(A.nRows(), A.nColumns(), A.ptrColumn(0),
A.nRows(), &A.ipiv()[0], info);
if (info > 0) {
if (A.m_printLevel) {
writelogf("solve(DenseMatrix& A, double* b): DGETRF returned INFO = %d U(i,i) is exactly zero. The factorization has"
" been completed, but the factor U is exactly singular, and division by zero will occur if "
"it is used to solve a system of equations.\n", info);
}
if (!A.m_useReturnErrorCode) {
throw CanteraError("solve(DenseMatrix& A, double* b)",
"DGETRF returned INFO = {}. U(i,i) is exactly zero. The factorization has"
" been completed, but the factor U is exactly singular, and division by zero will occur if "
"it is used to solve a system of equations.", info);
}
return info;
} else if (info < 0) {
if (A.m_printLevel) {
writelogf("solve(DenseMatrix& A, double* b): DGETRF returned INFO = %d. The argument i has an illegal value\n", info);
}

throw CanteraError("solve(DenseMatrix& A, double* b)",
"DGETRF returned INFO = {}. The argument i has an illegal value", info);
}

int info = 0;
if (ldb == 0) {
ldb = A.nColumns();
}
ct_dgetrs(ctlapack::NoTranspose, A.nRows(), nrhs, A.ptrColumn(0),
A.nRows(), &A.ipiv()[0], b, ldb, info);
if (info != 0) {
if (A.m_printLevel) {
writelogf("solve(DenseMatrix& A, double* b): DGETRS returned INFO = %d\n", info);
#if CT_USE_LAPACK
ct_dgetrf(A.nRows(), A.nColumns(), A.ptrColumn(0),
A.nRows(), &A.ipiv()[0], info);
if (info > 0) {
if (A.m_printLevel) {
writelogf("solve(DenseMatrix& A, double* b): DGETRF returned INFO = %d U(i,i) is exactly zero. The factorization has"
" been completed, but the factor U is exactly singular, and division by zero will occur if "
"it is used to solve a system of equations.\n", info);
}
if (!A.m_useReturnErrorCode) {
throw CanteraError("solve(DenseMatrix& A, double* b)",
"DGETRF returned INFO = {}. U(i,i) is exactly zero. The factorization has"
" been completed, but the factor U is exactly singular, and division by zero will occur if "
"it is used to solve a system of equations.", info);
}
return info;
} else if (info < 0) {
if (A.m_printLevel) {
writelogf("solve(DenseMatrix& A, double* b): DGETRF returned INFO = %d. The argument i has an illegal value\n", info);
}

throw CanteraError("solve(DenseMatrix& A, double* b)",
"DGETRF returned INFO = {}. The argument i has an illegal value", info);
}
if (info < 0 || !A.m_useReturnErrorCode) {
throw CanteraError("solve(DenseMatrix& A, double* b)", "DGETRS returned INFO = {}", info);

ct_dgetrs(ctlapack::NoTranspose, A.nRows(), nrhs, A.ptrColumn(0),
A.nRows(), &A.ipiv()[0], b, ldb, info);
if (info != 0) {
if (A.m_printLevel) {
writelogf("solve(DenseMatrix& A, double* b): DGETRS returned INFO = %d\n", info);
}
if (info < 0 || !A.m_useReturnErrorCode) {
throw CanteraError("solve(DenseMatrix& A, double* b)", "DGETRS returned INFO = {}", info);
}
}
}
#else
MappedMatrix Am(&A(0,0), A.nRows(), A.nColumns());
#ifdef NDEBUG
auto lu = Am.partialPivLu();
#else
auto lu = Am.fullPivLu();
if (lu.nonzeroPivots() < static_cast<long int>(A.nColumns())) {
throw CanteraError("solve(DenseMatrix& A, double* b)",
"Matrix appears to be rank-deficient: non-zero pivots = {}; columns = {}",
lu.nonzeroPivots(), A.nColumns());
}
#endif
for (size_t i = 0; i < nrhs; i++) {
MappedVector bm(b + ldb*i, A.nColumns());
bm = lu.solve(bm);
}
#endif
return info;
}

Expand All @@ -183,46 +213,60 @@ int solve(DenseMatrix& A, DenseMatrix& b)

void multiply(const DenseMatrix& A, const double* const b, double* const prod)
{
ct_dgemv(ctlapack::ColMajor, ctlapack::NoTranspose,
static_cast<int>(A.nRows()), static_cast<int>(A.nColumns()), 1.0,
A.ptrColumn(0), static_cast<int>(A.nRows()), b, 1, 0.0, prod, 1);
A.mult(b, prod);
}

void increment(const DenseMatrix& A, const double* b, double* prod)
{
ct_dgemv(ctlapack::ColMajor, ctlapack::NoTranspose,
static_cast<int>(A.nRows()), static_cast<int>(A.nColumns()), 1.0,
A.ptrColumn(0), static_cast<int>(A.nRows()), b, 1, 1.0, prod, 1);
#if CT_USE_LAPACK
ct_dgemv(ctlapack::ColMajor, ctlapack::NoTranspose,
static_cast<int>(A.nRows()), static_cast<int>(A.nColumns()), 1.0,
A.ptrColumn(0), static_cast<int>(A.nRows()), b, 1, 1.0, prod, 1);
#else
MappedMatrix Am(&const_cast<DenseMatrix&>(A)(0,0), A.nRows(), A.nColumns());
MappedVector bm(const_cast<double*>(b), A.nColumns());
MappedVector pm(prod, A.nRows());
pm += Am * bm;
#endif
}

int invert(DenseMatrix& A, size_t nn)
{
integer n = static_cast<int>(nn != npos ? nn : A.nRows());
int info=0;
ct_dgetrf(n, n, A.ptrColumn(0), static_cast<int>(A.nRows()),
&A.ipiv()[0], info);
if (info != 0) {
if (A.m_printLevel) {
writelogf("invert(DenseMatrix& A, int nn): DGETRS returned INFO = %d\n", info);
}
if (! A.m_useReturnErrorCode) {
throw CanteraError("invert(DenseMatrix& A, int nn)", "DGETRS returned INFO = {}", info);
#if CT_USE_LAPACK
integer n = static_cast<int>(nn != npos ? nn : A.nRows());
ct_dgetrf(n, n, A.ptrColumn(0), static_cast<int>(A.nRows()),
&A.ipiv()[0], info);
if (info != 0) {
if (A.m_printLevel) {
writelogf("invert(DenseMatrix& A, int nn): DGETRS returned INFO = %d\n", info);
}
if (! A.m_useReturnErrorCode) {
throw CanteraError("invert(DenseMatrix& A, int nn)", "DGETRS returned INFO = {}", info);
}
return info;
}
return info;
}

vector_fp work(n);
integer lwork = static_cast<int>(work.size());
ct_dgetri(n, A.ptrColumn(0), static_cast<int>(A.nRows()),
&A.ipiv()[0], &work[0], lwork, info);
if (info != 0) {
if (A.m_printLevel) {
writelogf("invert(DenseMatrix& A, int nn): DGETRS returned INFO = %d\n", info);
vector_fp work(n);
integer lwork = static_cast<int>(work.size());
ct_dgetri(n, A.ptrColumn(0), static_cast<int>(A.nRows()),
&A.ipiv()[0], &work[0], lwork, info);
if (info != 0) {
if (A.m_printLevel) {
writelogf("invert(DenseMatrix& A, int nn): DGETRS returned INFO = %d\n", info);
}
if (! A.m_useReturnErrorCode) {
throw CanteraError("invert(DenseMatrix& A, int nn)", "DGETRI returned INFO={}", info);
}
}
if (! A.m_useReturnErrorCode) {
throw CanteraError("invert(DenseMatrix& A, int nn)", "DGETRI returned INFO={}", info);
#else
MappedMatrix Am(&A(0,0), A.nRows(), A.nColumns());
if (nn == npos) {
Am = Am.inverse();
} else {
Am.topLeftCorner(nn, nn) = Am.topLeftCorner(nn, nn).inverse();
}
}
#endif
return info;
}

Expand Down

0 comments on commit b49c2e4

Please sign in to comment.