diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h index 5357c39f08..9997e37543 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h @@ -11,26 +11,18 @@ namespace torchao::ops::linear_8bit_act_xbit_weight { torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( - int weight_nbit, + unsigned char weight_nbit, bool has_weight_zeros, bool has_bias, - int nr, - int kr, - int version = 1) { - TORCHAO_CHECK( - version >= 0 && version < 256, "version must be between 0 and 255"); - TORCHAO_CHECK( - weight_nbit >= 1 && weight_nbit < 256, - "weight_nbit must be between 1 and 255"); + unsigned short nr, + unsigned short kr, + unsigned char version = 1) { return torchao::ops::PackedWeightsHeader( torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal, - {((static_cast(version) << 8) | - static_cast(weight_nbit)), - ((static_cast(has_weight_zeros) << 8) | - static_cast(has_bias)), - static_cast(nr), - static_cast(kr), - 0, + {static_cast((version << 8) | weight_nbit), + static_cast((has_weight_zeros << 8) | has_bias), + nr, + kr, 0, 0, 0});