Skip to content

Commit

Permalink
fix csr strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 16, 2019
1 parent 2342c02 commit 28f9fa6
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions include/ginkgo/core/matrix/csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
* strategy_type is to decide how to set the csr algorithm.
*
* The practical strategy method should inherit strategy_type and implement
* its `process`, `calc_size` function and the corresponding device kernel.
* its `process`, `clac_size` function and the corresponding device kernel.
*/
class strategy_type {
friend class automatical;
Expand Down Expand Up @@ -159,7 +159,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
*
* @return the size of srow
*/
virtual int64_t calc_size(const int64_t nnz) = 0;
virtual int64_t clac_size(const int64_t nnz) = 0;

protected:
void set_name(std::string name) { name_ = name; }
Expand All @@ -183,7 +183,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
Array<index_type> *mtx_srow) override
{}

int64_t calc_size(const int64_t nnz) override { return 0; }
int64_t clac_size(const int64_t nnz) override { return 0; }
};

/**
Expand All @@ -202,7 +202,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
Array<index_type> *mtx_srow) override
{}

int64_t calc_size(const int64_t nnz) override { return 0; }
int64_t clac_size(const int64_t nnz) override { return 0; }
};

/**
Expand All @@ -222,7 +222,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
Array<index_type> *mtx_srow) override
{}

int64_t calc_size(const int64_t nnz) override { return 0; }
int64_t clac_size(const int64_t nnz) override { return 0; }
};

/**
Expand All @@ -241,7 +241,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
Array<index_type> *mtx_srow) override
{}

int64_t calc_size(const int64_t nnz) override { return 0; }
int64_t clac_size(const int64_t nnz) override { return 0; }
};

/**
Expand Down Expand Up @@ -345,7 +345,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
}
}

int64_t calc_size(const int64_t nnz) override
int64_t clac_size(const int64_t nnz) override
{
if (warp_size_ > 0) {
int multiple = 8;
Expand Down Expand Up @@ -446,7 +446,8 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
}
const auto num_rows = mtx_row_ptrs.get_num_elems() - 1;
if (row_ptrs[num_rows] > index_type(1e6)) {
load_balance actual_strategy(nwarps_);
load_balance actual_strategy(nwarps_, warp_size_,
cuda_strategy_);
if (is_mtx_on_host) {
actual_strategy.process(mtx_row_ptrs, mtx_srow);
} else {
Expand All @@ -459,7 +460,8 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
maxnum = max(maxnum, row_ptrs[i] - row_ptrs[i - 1]);
}
if (maxnum > 64) {
load_balance actual_strategy(nwarps_);
load_balance actual_strategy(nwarps_, warp_size_,
cuda_strategy_);
if (is_mtx_on_host) {
actual_strategy.process(mtx_row_ptrs, mtx_srow);
} else {
Expand All @@ -478,11 +480,11 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
}
}

int64_t calc_size(const int64_t nnz) override
int64_t clac_size(const int64_t nnz) override
{
return std::make_shared<load_balance>(nwarps_, warp_size_,
cuda_strategy_)
->calc_size(nnz);
->clac_size(nnz);
}

private:
Expand Down Expand Up @@ -710,7 +712,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
col_idxs_(exec, num_nonzeros),
// avoid allocation for empty matrix
row_ptrs_(exec, size[0] + (size[0] > 0)),
srow_(exec, strategy->calc_size(num_nonzeros)),
srow_(exec, strategy->clac_size(num_nonzeros)),
strategy_(std::move(strategy))
{}

Expand Down Expand Up @@ -762,7 +764,7 @@ class Csr : public EnableLinOp<Csr<ValueType, IndexType>>,
*/
void make_srow()
{
srow_.resize_and_reset(strategy_->calc_size(values_.get_num_elems()));
srow_.resize_and_reset(strategy_->clac_size(values_.get_num_elems()));
strategy_->process(row_ptrs_, &srow_);
}

Expand Down

0 comments on commit 28f9fa6

Please sign in to comment.