diff --git a/types/field_type.go b/types/field_type.go index ae4ca8bc1d614..767d297f634ff 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -72,6 +72,7 @@ func AggFieldType(tps []*FieldType) *FieldType { } mtp := MergeFieldType(currType.Tp, t.Tp) currType.Tp = mtp + currType.Flag = mergeTypeFlag(currType.Flag, t.Flag) } return &currType @@ -308,6 +309,13 @@ func MergeFieldType(a byte, b byte) byte { return fieldTypeMergeRules[ia][ib] } +// mergeTypeFlag merges two MySQL type flag to a new one +// currently only NotNullFlag is checked +// todo more flag need to be checked, for example: UnsignedFlag +func mergeTypeFlag(a, b uint) uint { + return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag) +} + func getFieldTypeIndex(tp byte) int { itp := int(tp) if itp < fieldTypeTearFrom { diff --git a/types/field_type_test.go b/types/field_type_test.go index 31c5fd97a2e4d..9b50d65e9a7d9 100644 --- a/types/field_type_test.go +++ b/types/field_type_test.go @@ -300,6 +300,32 @@ func (s *testFieldTypeSuite) TestAggFieldType(c *C) { } } } +func (s *testFieldTypeSuite) TestAggFieldTypeForTypeFlag(c *C) { + types := []*FieldType{ + NewFieldType(mysql.TypeLonglong), + NewFieldType(mysql.TypeLonglong), + } + + aggTp := AggFieldType(types) + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + c.Assert(aggTp.Flag, Equals, uint(0)) + + types[0].Flag = mysql.NotNullFlag + aggTp = AggFieldType(types) + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + c.Assert(aggTp.Flag, Equals, uint(0)) + + types[0].Flag = 0 + types[1].Flag = mysql.NotNullFlag + aggTp = AggFieldType(types) + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + c.Assert(aggTp.Flag, Equals, uint(0)) + + types[0].Flag = mysql.NotNullFlag + aggTp = AggFieldType(types) + c.Assert(aggTp.Tp, Equals, mysql.TypeLonglong) + c.Assert(aggTp.Flag, Equals, mysql.NotNullFlag) +} func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) { defer testleak.AfterTest(c)()