diff --git a/ml_dtypes/include/int4.h b/ml_dtypes/include/int4.h index c8ec5b06..419cc9f3 100644 --- a/ml_dtypes/include/int4.h +++ b/ml_dtypes/include/int4.h @@ -48,7 +48,7 @@ struct i4 { return std::is_signed::value ? i4(7) : i4(15); } - template >> + template explicit constexpr operator T() const { return static_cast(v); } diff --git a/ml_dtypes/tests/int4_test.cc b/ml_dtypes/tests/int4_test.cc index fc0304d5..d7cf538f 100644 --- a/ml_dtypes/tests/int4_test.cc +++ b/ml_dtypes/tests/int4_test.cc @@ -287,12 +287,26 @@ TYPED_TEST(Int4Test, ToString) { } } +struct CustomInt { + constexpr CustomInt() : x(0) {} + constexpr CustomInt(int x) : x(x) {} + // NOLINTNEXTLINE(google-explicit-constructor) + constexpr operator int() const { return x; } + constexpr bool operator==(const CustomInt& other) const { + return x == other.x; + } + + private: + int x; +}; + #define GEN_DEST_TYPES(Type) \ std::pair, std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ std::pair, std::pair, \ - std::pair, std::pair + std::pair, std::pair, \ + std::pair #define GEN_TYPE_PAIRS() GEN_DEST_TYPES(int4), GEN_DEST_TYPES(uint4)