Skip to content

Commit

Permalink
fixes backward compatibility for iteration complete event
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcelKoch committed May 4, 2023
1 parent 1f5d219 commit 7d3b7b6
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 48 deletions.
22 changes: 22 additions & 0 deletions core/log/profiler_hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,28 @@ void ProfilerHook::on_iteration_complete(
}


void ProfilerHook::on_iteration_complete(const LinOp* solver,
const size_type& num_iterations,
const LinOp* residual,
const LinOp* solution,
const LinOp* residual_norm) const
{
on_iteration_complete(solver, nullptr, solution, num_iterations, residual,
residual_norm, nullptr, nullptr, false);
}


void ProfilerHook::on_iteration_complete(
const LinOp* solver, const size_type& num_iterations, const LinOp* residual,
const LinOp* solution, const LinOp* residual_norm,
const LinOp* implicit_sq_residual_norm) const
{
on_iteration_complete(solver, nullptr, solution, num_iterations, residual,
residual_norm, implicit_sq_residual_norm, nullptr,
false);
}


bool ProfilerHook::needs_propagation() const { return true; }


Expand Down
24 changes: 24 additions & 0 deletions core/log/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,30 @@ void Stream<ValueType>::on_iteration_complete(
}


template <typename ValueType>
void Stream<ValueType>::on_iteration_complete(const LinOp* solver,
const size_type& num_iterations,
const LinOp* residual,
const LinOp* solution,
const LinOp* residual_norm) const
{
on_iteration_complete(solver, nullptr, solution, num_iterations, residual,
residual_norm, nullptr, nullptr, false);
}


template <typename ValueType>
void Stream<ValueType>::on_iteration_complete(
const LinOp* solver, const size_type& num_iterations, const LinOp* residual,
const LinOp* solution, const LinOp* residual_norm,
const LinOp* implicit_sq_residual_norm) const
{
on_iteration_complete(solver, nullptr, solution, num_iterations, residual,
residual_norm, implicit_sq_residual_norm, nullptr,
false);
}


#define GKO_DECLARE_STREAM(_type) class Stream<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_STREAM);

Expand Down
3 changes: 2 additions & 1 deletion core/test/log/logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ struct DummyLoggedClass : gko::log::EnableLogging<DummyLoggedClass> {
void apply()
{
this->log<gko::log::Logger::iteration_complete>(
nullptr, num_iters, nullptr, nullptr, nullptr);
nullptr, nullptr, nullptr, num_iters, nullptr, nullptr, nullptr,
nullptr, false);
}

std::shared_ptr<const gko::Executor> get_executor() const { return exec; }
Expand Down
111 changes: 77 additions & 34 deletions include/ginkgo/core/log/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,27 +422,19 @@ public: \
one_changed, all_converged);
}

/**
* Register the `iteration_complete` event which logs every completed
* iterations.
*
* @param solver the solver executing the iteration
* @param b the right-hand-side vector
* @param x the solution vector
* @param it the current iteration count
* @param r the residual (optional)
* @param tau the implicit residual norm squared (optional)
* @param implicit_tau_sq the residual norm (optional)
* @param status the stopping status of the right hand sides (optional)
* @param stopped whether all right hand sides have stopped (invalid if
* status is not provided)
*/
GKO_LOGGER_REGISTER_EVENT(21, iteration_complete, const LinOp* solver,
const LinOp* b, const LinOp* x,
const size_type& it, const LinOp* r,
const LinOp* tau, const LinOp* implicit_tau_sq,
const array<stopping_status>* status,
bool stopped)
public:
static constexpr size_type iteration_complete{21};
static constexpr mask_type iteration_complete_mask{mask_type{1} << 21};

template <size_type Event, typename... Params>
std::enable_if_t<Event == 21 && (21 < event_count_max)> on(
Params&&... params) const
{
if (enabled_events_ & (mask_type{1} << 21)) {
this->on_iteration_complete(std::forward<Params>(params)...);
}
}

protected:
/**
* Register the `iteration_complete` event which logs every completed
Expand All @@ -453,20 +445,16 @@ public: \
* @param x the solution vector (optional)
* @param tau the residual norm (optional)
*
* @note The on_iteration_complete function that this macro declares is
* deprecated. Please use the one with the additional implicit_tau_sq
* parameter as below.
* @warning This on_iteration_complete function that this macro declares is
* deprecated. Please use the version with the stopping information.
*/
[[deprecated(
"Please use the version with the additional implicit_tau_sq, status "
"and stopped parameter.")]] virtual void
"Please use the version with the additional stopping "
"information.")]] virtual void
on_iteration_complete(const LinOp* solver, const size_type& it,
const LinOp* r, const LinOp* x = nullptr,
const LinOp* tau = nullptr) const
{
this->on_iteration_complete(solver, nullptr, x, it, r, tau, nullptr,
nullptr, false);
}
{}

/**
* Register the `iteration_complete` event which logs every completed
Expand All @@ -477,16 +465,71 @@ public: \
* @param x the solution vector (optional)
* @param tau the residual norm (optional)
* @param implicit_tau_sq the implicit residual norm squared (optional)
*
* @warning This on_iteration_complete function that this macro declares is
* deprecated. Please use the version with the stopping information.
*/
[[deprecated(
"Please use the version with the additional status and stopped "
"parameter.")]] virtual void
"Please use the version with the additional stopping "
"information.")]] virtual void
on_iteration_complete(const LinOp* solver, const size_type& it,
const LinOp* r, const LinOp* x, const LinOp* tau,
const LinOp* implicit_tau_sq) const
{
this->on_iteration_complete(solver, nullptr, x, it, r, tau,
implicit_tau_sq, nullptr, false);
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 5211, 4973, 4974)
#endif
this->on_iteration_complete(solver, it, r, x, tau);
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
#ifdef _MSC_VER
#pragma warning(pop)
#endif
}

/**
* Register the `iteration_complete` event which logs every completed
* iterations.
*
* @param solver the solver executing the iteration
* @param b the right-hand-side vector
* @param x the solution vector
* @param it the current iteration count
* @param r the residual (optional)
* @param tau the implicit residual norm squared (optional)
* @param implicit_tau_sq the residual norm (optional)
* @param status the stopping status of the right hand sides (optional)
* @param stopped whether all right hand sides have stopped (invalid if
* status is not provided)
*/
virtual void on_iteration_complete(const LinOp* solver, const LinOp* b,
const LinOp* x, const size_type& it,
const LinOp* r, const LinOp* tau,
const LinOp* implicit_tau_sq,
const array<stopping_status>* status,
bool stopped) const
{
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 5211, 4973, 4974)
#endif
this->on_iteration_complete(solver, it, r, x, tau);
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
#ifdef _MSC_VER
#pragma warning(pop)
#endif
}

public:
Expand Down
17 changes: 11 additions & 6 deletions include/ginkgo/core/log/papi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,17 @@ class Papi : public Logger {
const array<stopping_status>* status,
bool stopped) const override;

void on_iteration_complete(
const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution = nullptr,
const LinOp* residual_norm = nullptr) const override;

void on_iteration_complete(
[[deprecated(
"Please use the version with the additional stopping "
"information.")]] void
on_iteration_complete(const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution,
const LinOp* residual_norm) const override;

[[deprecated(
"Please use the version with the additional stopping "
"information.")]] void
on_iteration_complete(
const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution,
const LinOp* residual_norm,
Expand Down
16 changes: 16 additions & 0 deletions include/ginkgo/core/log/profiler_hook.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,22 @@ class ProfilerHook : public Logger {
const LinOp* implicit_sq_residual_norm,
const array<stopping_status>* status, bool stopped) const override;

[[deprecated(
"Please use the version with the additional stopping "
"information.")]] void
on_iteration_complete(const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution,
const LinOp* residual_norm) const override;

[[deprecated(
"Please use the version with the additional stopping "
"information.")]] void
on_iteration_complete(
const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution,
const LinOp* residual_norm,
const LinOp* implicit_sq_residual_norm) const override;

bool needs_propagation() const override;

/**
Expand Down
21 changes: 14 additions & 7 deletions include/ginkgo/core/log/record.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ struct iteration_complete_data {
: num_iterations{num_iterations}, all_stopped(all_stopped)
{
this->solver = solver->clone();
this->right_hand_side = right_hand_side->clone();
this->solution = solution->clone();
if (right_hand_side != nullptr) {
this->right_hand_side = right_hand_side->clone();
}
if (residual != nullptr) {
this->residual = residual->clone();
}
Expand Down Expand Up @@ -400,12 +402,17 @@ class Record : public Logger {
const LinOp* residual_norm, const LinOp* implicit_resnorm_sq,
const array<stopping_status>* status, bool stopped) const override;

void on_iteration_complete(
const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution = nullptr,
const LinOp* residual_norm = nullptr) const override;

void on_iteration_complete(
[[deprecated(
"Please use the version with the additional stopping "
"information.")]] void
on_iteration_complete(const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution,
const LinOp* residual_norm) const override;

[[deprecated(
"Please use the version with the additional stopping "
"information.")]] void
on_iteration_complete(
const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution,
const LinOp* residual_norm,
Expand Down
16 changes: 16 additions & 0 deletions include/ginkgo/core/log/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,22 @@ class Stream : public Logger {
const array<stopping_status>* status,
bool stopped) const override;

[[deprecated(
"Please use the version with the additional stopping "
"information.")]] void
on_iteration_complete(const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution,
const LinOp* residual_norm) const override;

[[deprecated(
"Please use the version with the additional stopping "
"information.")]] void
on_iteration_complete(
const LinOp* solver, const size_type& num_iterations,
const LinOp* residual, const LinOp* solution,
const LinOp* residual_norm,
const LinOp* implicit_sq_residual_norm) const override;

/**
* Creates a Stream logger. This dynamically allocates the memory,
* constructs the object and returns an std::unique_ptr to this object.
Expand Down

0 comments on commit 7d3b7b6

Please sign in to comment.