diff --git a/cml/matrix/comparison.h b/cml/matrix/comparison.h new file mode 100644 index 0000000..a35210e --- /dev/null +++ b/cml/matrix/comparison.h @@ -0,0 +1,37 @@ +/* -*- C++ -*- ------------------------------------------------------------ + @@COPYRIGHT@@ + *-----------------------------------------------------------------------*/ +/** @file + */ + +#pragma once + +#ifndef cml_matrix_comparison_h +#define cml_matrix_comparison_h + +#include + +namespace cml { + +/** Returns true if the elements of @c left are all equal to the elements + * of @c right. + */ +template bool operator==( + const readable_matrix& left, const readable_matrix& right); + +/** Returns true if some element of @c left is not equal to the same element + * of @c right. + */ +template bool operator!=( + const readable_matrix& left, const readable_matrix& right); + +} // namespace cml + +#define __CML_MATRIX_COMPARISON_TPP +#include +#undef __CML_MATRIX_COMPARISON_TPP + +#endif + +// ------------------------------------------------------------------------- +// vim:ft=cpp:sw=2 diff --git a/cml/matrix/comparison.tpp b/cml/matrix/comparison.tpp new file mode 100644 index 0000000..fa1e549 --- /dev/null +++ b/cml/matrix/comparison.tpp @@ -0,0 +1,41 @@ +/* -*- C++ -*- ------------------------------------------------------------ + @@COPYRIGHT@@ + *-----------------------------------------------------------------------*/ +/** @file + */ + +#ifndef __CML_MATRIX_COMPARISON_TPP +#error "matrix/comparison.tpp not included correctly" +#endif + +#include + +namespace cml { + +template inline bool operator==( + const readable_matrix& left, const readable_matrix& right + ) +{ + /* Possibly equal only if the same dimensions: */ + if(left.size() != right.size()) return false; + for(int i = 0; i < left.rows(); i ++) { + for(int j = 0; j < left.cols(); j ++) { + /**/ if(left(i, j) < right(i, j)) return false; // Strictly less. + else if(right(i, j) < left(i, j)) return false; // Strictly greater. + else continue; // Possibly equal. + } + } + return true; +} + +template inline bool operator!=( + const readable_matrix& left, const readable_matrix& right + ) +{ + return !(left == right); +} + +} // namespace cml + +// ------------------------------------------------------------------------- +// vim:ft=cpp:sw=2 diff --git a/cml/vector/comparison.tpp b/cml/vector/comparison.tpp index e5bc3c3..d7dbc17 100644 --- a/cml/vector/comparison.tpp +++ b/cml/vector/comparison.tpp @@ -16,39 +16,42 @@ template inline bool operator<( const readable_vector& left, const readable_vector& right ) { - cml::check_same_size(left, right); - for(int i = 0; i < left.size(); i ++) { + int n = std::min(left.size(), right.size()); + for(int i = 0; i < n; i ++) { /**/ if(left[i] < right[i]) return true; // Strictly less. - else if(left[i] > right[i]) return false; // Strictly greater. - else continue; // Equal. + else if(right[i] < left[i]) return false; // Strictly greater. + else continue; // Possibly equal. } - /* Equal. */ - return false; + /* Equal only if the same length: */ + return left.size() < right.size(); } template inline bool operator>( const readable_vector& left, const readable_vector& right ) { - cml::check_same_size(left, right); - for(int i = 0; i < left.size(); i ++) { - /**/ if(left[i] > right[i]) return true; // Strictly greater. - else if(left[i] < right[i]) return false; // Strictly less. - else continue; // Equal. + int n = std::min(left.size(), right.size()); + for(int i = 0; i < n; i ++) { + /**/ if(left[i] < right[i]) return false; // Strictly less. + else if(right[i] < left[i]) return true; // Strictly greater. + else continue; // Possibly equal. } - /* Equal. */ - return false; + /* Equal only if the same length: */ + return left.size() > right.size(); } template inline bool operator==( const readable_vector& left, const readable_vector& right ) { - cml::check_same_size(left, right); + /* Possibly equal only if the same length: */ + if(left.size() != right.size()) return false; for(int i = 0; i < left.size(); i ++) { - if(!(left[i] == right[i])) return false; // Not equal. + /**/ if(left[i] < right[i]) return false; // Strictly less. + else if(right[i] < left[i]) return false; // Strictly greater. + else continue; // Possibly equal. } return true; } diff --git a/tests/matrix/CMakeLists.txt b/tests/matrix/CMakeLists.txt index 53fff8a..ccd455e 100644 --- a/tests/matrix/CMakeLists.txt +++ b/tests/matrix/CMakeLists.txt @@ -23,6 +23,7 @@ CML_ADD_TEST(rowcol1) CML_ADD_TEST(lu1) CML_ADD_TEST(determinant1) CML_ADD_TEST(matrix_hadamard_product1) +CML_ADD_TEST(matrix_comparison1) # -------------------------------------------------------------------------- # vim:ft=cmake diff --git a/tests/matrix/matrix_comparison1.cpp b/tests/matrix/matrix_comparison1.cpp new file mode 100644 index 0000000..a1276ee --- /dev/null +++ b/tests/matrix/matrix_comparison1.cpp @@ -0,0 +1,56 @@ +/* -*- C++ -*- ------------------------------------------------------------ + @@COPYRIGHT@@ + *-----------------------------------------------------------------------*/ +/** @file + */ + +// Make sure the main header compiles cleanly: +#include + +#include +#include + +// For Catch: +#include + +/* Testing headers: */ +#include "catch_runner.h" + +CATCH_TEST_CASE("equal1") +{ + auto M = cml::matrix33d( + 1., 2., 3., + 1., 4., 9., + 1., 16., 25. + ); + + M.transpose(); + auto expected = cml::matrix33d( + 1., 1., 1., + 2., 4., 16., + 3., 9., 25. + ); + + CATCH_CHECK(M == expected); +} + +CATCH_TEST_CASE("not_equal1") +{ + auto M = cml::matrix33d( + 1., 2., 3., + 1., 4., 9., + 1., 16., 25. + ); + + M.transpose(); + auto expected = cml::matrix33d( + 1., 1., 1., + 2., 4., 16., + 3., 9., 24. + ); + + CATCH_CHECK(M != expected); +} + +// ------------------------------------------------------------------------- +// vim:ft=cpp:sw=2