diff --git a/lib/column/column_gen.go b/lib/column/column_gen.go index 5bd1180a3b..943bb047a5 100644 --- a/lib/column/column_gen.go +++ b/lib/column/column_gen.go @@ -136,7 +136,7 @@ func (t Type) Column(name string, tz *time.Location) (Interface, error) { case "Point": return &Point{name: name}, nil case "String": - return &String{name: name, col: colStrProvider()}, nil + return &String{name: name, col: colStrProvider(name)}, nil case "Object('json')": return &JSONObject{name: name, root: true, tz: tz}, nil } diff --git a/lib/column/column_gen_option.go b/lib/column/column_gen_option.go index 03f93694cc..6a883527b4 100644 --- a/lib/column/column_gen_option.go +++ b/lib/column/column_gen_option.go @@ -20,13 +20,13 @@ package column import "github.com/ClickHouse/ch-go/proto" // ColStrProvider defines provider of proto.ColStr -type ColStrProvider func() proto.ColStr +type ColStrProvider func(name string) proto.ColStr // colStrProvider provide proto.ColStr for Column() when type is String var colStrProvider ColStrProvider = defaultColStrProvider // defaultColStrProvider defines sample provider for proto.ColStr -func defaultColStrProvider() proto.ColStr { +func defaultColStrProvider(string) proto.ColStr { return proto.ColStr{} } @@ -35,7 +35,7 @@ func defaultColStrProvider() proto.ColStr { // // It is more suitable for scenarios where a lot of data is written in batches func WithAllocBufferColStrProvider(cap int) { - colStrProvider = func() proto.ColStr { + colStrProvider = func(string) proto.ColStr { return proto.ColStr{Buf: make([]byte, 0, cap)} } } diff --git a/lib/column/enum.go b/lib/column/enum.go index 25d2e2199d..c0b8a4fbaa 100644 --- a/lib/column/enum.go +++ b/lib/column/enum.go @@ -46,6 +46,8 @@ func Enum(chType Type, name string) (Interface, error) { v := int8(indexes[i]) enum.iv[values[i]] = proto.Enum8(v) enum.vi[proto.Enum8(v)] = values[i] + + enum.enumValuesBitset[uint8(v)>>6] |= 1 << (v & 63) } return &enum, nil } @@ -54,12 +56,23 @@ func Enum(chType 
Type, name string) (Interface, error) { vi: make(map[proto.Enum16]string, len(values)), chType: chType, name: name, + // to be updated below, when ranging over all index/enum values + minEnum: math.MaxInt16, + maxEnum: math.MinInt16, } for i := range values { - enum.iv[values[i]] = proto.Enum16(indexes[i]) - enum.vi[proto.Enum16(indexes[i])] = values[i] + k := int16(indexes[i]) + enum.iv[values[i]] = proto.Enum16(k) + enum.vi[proto.Enum16(k)] = values[i] + if k < enum.minEnum { + enum.minEnum = k + } + if k > enum.maxEnum { + enum.maxEnum = k + } } + enum.continuous = (enum.maxEnum-enum.minEnum)+1 == int16(len(enum.vi)) return &enum, nil } diff --git a/lib/column/enum16.go b/lib/column/enum16.go index c394e7fff3..d3a15b80d6 100644 --- a/lib/column/enum16.go +++ b/lib/column/enum16.go @@ -31,6 +31,10 @@ type Enum16 struct { chType Type col proto.ColEnum16 name string + + continuous bool + minEnum int16 + maxEnum int16 } func (col *Enum16) Reset() { @@ -179,9 +183,17 @@ func (col *Enum16) Append(v any) (nulls []uint8, err error) { func (col *Enum16) AppendRow(elem any) error { switch elem := elem.(type) { case int16: - return col.AppendRow(int(elem)) + if col.continuous && elem >= col.minEnum && elem <= col.maxEnum { + col.col.Append(proto.Enum16(elem)) + } else { + return col.AppendRow(int(elem)) + } case *int16: - return col.AppendRow(int(*elem)) + if col.continuous && *elem >= col.minEnum && *elem <= col.maxEnum { + col.col.Append(proto.Enum16(*elem)) + } else { + return col.AppendRow(int(*elem)) + } case int: v := proto.Enum16(elem) _, ok := col.vi[v] diff --git a/lib/column/enum8.go b/lib/column/enum8.go index 4aee561ad7..2f29136559 100644 --- a/lib/column/enum8.go +++ b/lib/column/enum8.go @@ -31,6 +31,10 @@ type Enum8 struct { chType Type name string col proto.ColEnum8 + + // Encoding of the enums that have been specified by the user. + // Using this when appending rows, to validate the enum is valid. 
+ enumValuesBitset [4]uint64 } func (col *Enum8) Reset() { @@ -183,27 +187,25 @@ func (col *Enum8) AppendRow(elem any) error { case *int8: return col.AppendRow(int(*elem)) case int: - v := proto.Enum8(elem) - _, ok := col.vi[v] - if !ok { + // Check if the enum value is defined + if col.enumValuesBitset[uint8(elem)>>6]&(1<<(elem&63)) == 0 { return &Error{ Err: fmt.Errorf("unknown element %v", elem), ColumnType: string(col.chType), } } - col.col.Append(v) + col.col.Append(proto.Enum8(elem)) case *int: switch { case elem != nil: - v := proto.Enum8(*elem) - _, ok := col.vi[v] - if !ok { + // Check if the enum value is defined + if col.enumValuesBitset[uint8(*elem)>>6]&(1<<(*elem&63)) == 0 { return &Error{ Err: fmt.Errorf("unknown element %v", *elem), ColumnType: string(col.chType), } } - col.col.Append(v) + col.col.Append(proto.Enum8(*elem)) default: col.col.Append(0) } diff --git a/lib/column/enum_test.go b/lib/column/enum_test.go index ef8d26df17..8a2965ffbe 100644 --- a/lib/column/enum_test.go +++ b/lib/column/enum_test.go @@ -1,6 +1,7 @@ package column import ( + "slices" "testing" "github.com/stretchr/testify/assert" @@ -155,3 +156,52 @@ func TestExtractEnumNamedValues(t *testing.T) { }) } } + +func TestEnumValuesBoundsChecks(t *testing.T) { + tests := []struct { + name string + enumType string + validEnums []int + }{ + { + name: "Simple enum range", + enumType: "Enum8('-2'=-2,'-1'=-1,'0'=0,'1'=1,'2'=2)", + validEnums: createValidEnumsRange(-2, 2), + }, + { + name: "Full enum range", + enumType: 
"Enum8('-128'=-128,'-127'=-127,'-126'=-126,'-125'=-125,'-124'=-124,'-123'=-123,'-122'=-122,'-121'=-121,'-120'=-120,'-119'=-119,'-118'=-118,'-117'=-117,'-116'=-116,'-115'=-115,'-114'=-114,'-113'=-113,'-112'=-112,'-111'=-111,'-110'=-110,'-109'=-109,'-108'=-108,'-107'=-107,'-106'=-106,'-105'=-105,'-104'=-104,'-103'=-103,'-102'=-102,'-101'=-101,'-100'=-100,'-99'=-99,'-98'=-98,'-97'=-97,'-96'=-96,'-95'=-95,'-94'=-94,'-93'=-93,'-92'=-92,'-91'=-91,'-90'=-90,'-89'=-89,'-88'=-88,'-87'=-87,'-86'=-86,'-85'=-85,'-84'=-84,'-83'=-83,'-82'=-82,'-81'=-81,'-80'=-80,'-79'=-79,'-78'=-78,'-77'=-77,'-76'=-76,'-75'=-75,'-74'=-74,'-73'=-73,'-72'=-72,'-71'=-71,'-70'=-70,'-69'=-69,'-68'=-68,'-67'=-67,'-66'=-66,'-65'=-65,'-64'=-64,'-63'=-63,'-62'=-62,'-61'=-61,'-60'=-60,'-59'=-59,'-58'=-58,'-57'=-57,'-56'=-56,'-55'=-55,'-54'=-54,'-53'=-53,'-52'=-52,'-51'=-51,'-50'=-50,'-49'=-49,'-48'=-48,'-47'=-47,'-46'=-46,'-45'=-45,'-44'=-44,'-43'=-43,'-42'=-42,'-41'=-41,'-40'=-40,'-39'=-39,'-38'=-38,'-37'=-37,'-36'=-36,'-35'=-35,'-34'=-34,'-33'=-33,'-32'=-32,'-31'=-31,'-30'=-30,'-29'=-29,'-28'=-28,'-27'=-27,'-26'=-26,'-25'=-25,'-24'=-24,'-23'=-23,'-22'=-22,'-21'=-21,'-20'=-20,'-19'=-19,'-18'=-18,'-17'=-17,'-16'=-16,'-15'=-15,'-14'=-14,'-13'=-13,'-12'=-12,'-11'=-11,'-10'=-10,'-9'=-9,'-8'=-8,'-7'=-7,'-6'=-6,'-5'=-5,'-4'=-4,'-3'=-3,'-2'=-2,'-1'=-1,'0'=0,'1'=1,'2'=2,'3'=3,'4'=4,'5'=5,'6'=6,'7'=7,'8'=8,'9'=9,'10'=10,'11'=11,'12'=12,'13'=13,'14'=14,'15'=15,'16'=16,'17'=17,'18'=18,'19'=19,'20'=20,'21'=21,'22'=22,'23'=23,'24'=24,'25'=25,'26'=26,'27'=27,'28'=28,'29'=29,'30'=30,'31'=31,'32'=32,'33'=33,'34'=34,'35'=35,'36'=36,'37'=37,'38'=38,'39'=39,'40'=40,'41'=41,'42'=42,'43'=43,'44'=44,'45'=45,'46'=46,'47'=47,'48'=48,'49'=49,'50'=50,'51'=51,'52'=52,'53'=53,'54'=54,'55'=55,'56'=56,'57'=57,'58'=58,'59'=59,'60'=60,'61'=61,'62'=62,'63'=63,'64'=64,'65'=65,'66'=66,'67'=67,'68'=68,'69'=69,'70'=70,'71'=71,'72'=72,'73'=73,'74'=74,'75'=75,'76'=76,'77'=77,'78'=78,'79'=79,'80'=80,'81'=81,'82'=82,'83'=83,'84'=84,'85'=85,'86'=
86,'87'=87,'88'=88,'89'=89,'90'=90,'91'=91,'92'=92,'93'=93,'94'=94,'95'=95,'96'=96,'97'=97,'98'=98,'99'=99,'100'=100,'101'=101,'102'=102,'103'=103,'104'=104,'105'=105,'106'=106,'107'=107,'108'=108,'109'=109,'110'=110,'111'=111,'112'=112,'113'=113,'114'=114,'115'=115,'116'=116,'117'=117,'118'=118,'119'=119,'120'=120,'121'=121,'122'=122,'123'=123,'124'=124,'125'=125,'126'=126,'127'=127)", + validEnums: createValidEnumsRange(-128, 127), + }, + { + name: "Enum range with gaps", + enumType: "Enum8('-10'=-10,'-5'=-5,'0'=0,'1'=1,'5'=5,'10'=10)", + validEnums: []int{-10, -5, 0, 1, 5, 10}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e, err := Enum(Type(tt.enumType), tt.name) + assert.NoError(t, err) + + // Try appending the full enum8 range. If the value is in the validEnums slice it should not error + for i := -128; i < 128; i++ { + valid := e.AppendRow(i) + + if slices.Contains(tt.validEnums, i) { + assert.NoError(t, valid) + } else { + assert.Error(t, valid) + } + } + }) + } +} + +func createValidEnumsRange(min, max int) []int { + resultRange := make([]int, 0, max-min+1) + for i := min; i <= max; i++ { + resultRange = append(resultRange, i) + } + return resultRange +}