diff --git a/cmd/schemagen/keyspace.tmpl b/cmd/schemagen/keyspace.tmpl index ed48c96..cddb23d 100644 --- a/cmd/schemagen/keyspace.tmpl +++ b/cmd/schemagen/keyspace.tmpl @@ -2,7 +2,12 @@ package {{.PackageName}} -import "github.com/scylladb/gocqlx/v2/table" +import ( + "github.com/scylladb/gocqlx/v2/table" + {{- range .Imports}} + "{{.}}" + {{- end}} +) // Table models. var ( @@ -30,3 +35,29 @@ var ( {{end}} {{end}} ) + +{{with .UserTypes}} +{{range .}} +{{- $type_name := .Name | camelize}} +{{- $field_types := .FieldTypes}} +type {{$type_name}}UserType struct { +{{- range $index, $element := .FieldNames}} + {{- $type := index $field_types $index}} + {{. | camelize}} {{typeToString $type | mapScyllaToGoType}} +{{- end}} +} +{{- end}} +{{- end}} + +{{with .Tables}} +{{range .}} +{{- $model_name := .Name | camelize}} +type {{$model_name}}Struct struct { +{{- range .Columns}} + {{- if not (eq .Validator "empty") }} + {{.Name | camelize}} {{.Validator | mapScyllaToGoType}} + {{- end}} +{{- end}} +} +{{- end}} +{{- end}} diff --git a/cmd/schemagen/map_types.go b/cmd/schemagen/map_types.go new file mode 100644 index 0000000..4f59554 --- /dev/null +++ b/cmd/schemagen/map_types.go @@ -0,0 +1,105 @@ +package main + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/gocql/gocql" +) + +var types = map[string]string{ + "ascii": "string", + "bigint": "int64", + "blob": "[]byte", + "boolean": "bool", + "counter": "int", + "date": "time.Time", + "decimal": "inf.Dec", + "double": "float64", + "duration": "gocql.Duration", + "float": "float32", + "inet": "string", + "int": "int32", + "smallint": "int16", + "text": "string", + "time": "time.Duration", + "timestamp": "time.Time", + "timeuuid": "[16]byte", + "tinyint": "int8", + "uuid": "[16]byte", + "varchar": "string", + "varint": "int64", +} + +func mapScyllaToGoType(s string) string { + frozenRegex := regexp.MustCompile(`frozen<([a-z]*)>`) + match := frozenRegex.FindAllStringSubmatch(s, -1) + if match != nil { + s = match[0][1] + } + + mapRegex := regexp.MustCompile(`map<([a-z]*), ([a-z]*)>`) + setRegex := regexp.MustCompile(`set<([a-z]*)>`) + listRegex := regexp.MustCompile(`list<([a-z]*)>`) + tupleRegex := regexp.MustCompile(`tuple<(?:([a-z]*),? ?)*>`) + match = mapRegex.FindAllStringSubmatch(s, -1) + if match != nil { + key := match[0][1] + value := match[0][2] + + return "map[" + types[key] + "]" + types[value] + } + + match = setRegex.FindAllStringSubmatch(s, -1) + if match != nil { + key := match[0][1] + + return "[]" + types[key] + } + + match = listRegex.FindAllStringSubmatch(s, -1) + if match != nil { + key := match[0][1] + + return "[]" + types[key] + } + + match = tupleRegex.FindAllStringSubmatch(s, -1) + if match != nil { + tuple := match[0][0] + subStr := tuple[6 : len(tuple)-1] + types := strings.Split(subStr, ", ") + + typeStr := "struct {\n" + for i, t := range types { + typeStr = typeStr + "\t\tField" + strconv.Itoa(i+1) + " " + mapScyllaToGoType(t) + "\n" + } + typeStr = typeStr + "\t}" + + return typeStr + } + + t, exists := types[s] + if exists { + return t + } + + return camelize(s) + "UserType" +} + +func typeToString(t interface{}) string { + tType := fmt.Sprintf("%T", t) + switch tType { + case "gocql.NativeType": + return t.(gocql.NativeType).String() + case "gocql.CollectionType": + collectionType := t.(gocql.CollectionType).String() + collectionType = strings.Replace(collectionType, "(", "<", -1) + collectionType = strings.Replace(collectionType, ")", ">", -1) + return collectionType + default: + panic(fmt.Sprintf("Did not expect %v type in user defined type", tType)) + } +} diff --git a/cmd/schemagen/map_types_test.go b/cmd/schemagen/map_types_test.go new file mode 100644 index 0000000..d217418 --- /dev/null +++ b/cmd/schemagen/map_types_test.go @@ -0,0 +1,49 @@ +// Copyright (C) 2017 ScyllaDB +// Use of this source code is governed by a ALv2-style +// license that can be found in the LICENSE file. + +package main + +import ( + "testing" +) + +func TestMapScyllaToGoType(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"ascii", "string"}, + {"bigint", "int64"}, + {"blob", "[]byte"}, + {"boolean", "bool"}, + {"counter", "int"}, + {"date", "time.Time"}, + {"decimal", "inf.Dec"}, + {"double", "float64"}, + {"duration", "gocql.Duration"}, + {"float", "float32"}, + {"inet", "string"}, + {"int", "int32"}, + {"smallint", "int16"}, + {"text", "string"}, + {"time", "time.Duration"}, + {"timestamp", "time.Time"}, + {"timeuuid", "[16]byte"}, + {"tinyint", "int8"}, + {"uuid", "[16]byte"}, + {"varchar", "string"}, + {"varint", "int64"}, + {"map", "map[int32]string"}, + {"list", "[]int32"}, + {"set", "[]int32"}, + {"tuple", "struct {\n\t\tField1 bool\n\t\tField2 int32\n\t\tField3 int16\n\t}"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + if got := mapScyllaToGoType(tt.input); got != tt.want { + t.Errorf("mapScyllaToGoType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/schemagen/schemagen.go b/cmd/schemagen/schemagen.go index d54b747..99ae440 100644 --- a/cmd/schemagen/schemagen.go +++ b/cmd/schemagen/schemagen.go @@ -72,16 +72,35 @@ func renderTemplate(md *gocql.KeyspaceMetadata) ([]byte, error) { t, err := template. New("keyspace.tmpl"). Funcs(template.FuncMap{"camelize": camelize}). + Funcs(template.FuncMap{"mapScyllaToGoType": mapScyllaToGoType}). + Funcs(template.FuncMap{"typeToString": typeToString}). Parse(keyspaceTmpl) if err != nil { log.Fatalln("unable to parse models template:", err) } + imports := make([]string, 0) + for _, t := range md.Tables { + for _, c := range t.Columns { + if (c.Validator == "timestamp" || c.Validator == "date" || c.Validator == "duration" || c.Validator == "time") && !existsInSlice(imports, "time") { + imports = append(imports, "time") + } + if c.Validator == "decimal" && !existsInSlice(imports, "gopkg.in/inf.v0") { + imports = append(imports, "gopkg.in/inf.v0") + } + if c.Validator == "duration" && !existsInSlice(imports, "github.com/gocql/gocql") { + imports = append(imports, "github.com/gocql/gocql") + } + } + } + buf := &bytes.Buffer{} data := map[string]interface{}{ "PackageName": *flagPkgname, "Tables": md.Tables, + "UserTypes": md.UserTypes, + "Imports": imports, } if err = t.Execute(buf, data); err != nil { @@ -98,3 +117,13 @@ func createSession() (gocqlx.Session, error) { func clusterHosts() []string { return strings.Split(*flagCluster, ",") } + +func existsInSlice(s []string, v string) bool { + for _, i := range s { + if v == i { + return true + } + } + + return false +} diff --git a/cmd/schemagen/schemagen_test.go b/cmd/schemagen/schemagen_test.go index 1ca13af..382d21f 100644 --- a/cmd/schemagen/schemagen_test.go +++ b/cmd/schemagen/schemagen_test.go @@ -56,10 +56,17 @@ func createTestSchema(t *testing.T) { t.Fatal("create table:", err) } + err = session.ExecStmt(`CREATE TYPE IF NOT EXISTS schemagen.album ( + name text, + songwriters set,)`) + if err != nil { + t.Fatal("create type:", err) + } + err = session.ExecStmt(`CREATE TABLE IF NOT EXISTS schemagen.playlists ( id uuid, title text, - album text, + album frozen, artist text, song_id uuid, PRIMARY KEY (id, title, album, artist))`) diff --git a/cmd/schemagen/testdata/models.go.txt b/cmd/schemagen/testdata/models.go.txt index 912b297..6714a49 100644 --- a/cmd/schemagen/testdata/models.go.txt +++ b/cmd/schemagen/testdata/models.go.txt @@ -2,7 +2,9 @@ package foobar -import "github.com/scylladb/gocqlx/v2/table" +import ( + "github.com/scylladb/gocqlx/v2/table" +) // Table models. var ( @@ -41,3 +43,24 @@ var ( SortKey: []string{}, }) ) + +type AlbumUserType struct { + Name string + Songwriters []string +} + +type PlaylistsStruct struct { + Album AlbumUserType + Artist string + Id [16]byte + SongId [16]byte + Title string +} +type SongsStruct struct { + Album string + Artist string + Data []byte + Id [16]byte + Tags []string + Title string +} diff --git a/example_test.go b/example_test.go index fd37db0..641f750 100644 --- a/example_test.go +++ b/example_test.go @@ -37,6 +37,7 @@ func TestExample(t *testing.T) { session.ExecStmt(`DROP KEYSPACE examples`) basicCreateAndPopulateKeyspace(t, session) + createAndPopulateKeyspaceAllTypes(t, session) basicReadScyllaVersion(t, session) datatypesBlob(t, session) @@ -154,6 +155,170 @@ func basicCreateAndPopulateKeyspace(t *testing.T, session gocqlx.Session) { } } +// This example shows how to use query builders and table models to build +// queries with all types. It uses "BindStruct" function for parameter binding and "Select" +// function for loading data to a slice. +func createAndPopulateKeyspaceAllTypes(t *testing.T, session gocqlx.Session) { + err := session.ExecStmt(`CREATE KEYSPACE IF NOT EXISTS examples WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}`) + if err != nil { + t.Fatal("create keyspace:", err) + } + + // generated with schemagen + type CheckTypesStruct struct { + AsciI string + BigInt int64 + BloB []byte + BooleaN bool + DatE time.Time + DecimaL inf.Dec + DoublE float64 + DuratioN gocql.Duration + FloaT float32 + Id [16]byte + InT int32 + IneT string + ListInt []int32 + MapIntText map[int32]string + SetInt []int32 + SmallInt int16 + TexT string + TimE time.Duration + TimestamP time.Time + TimeuuiD [16]byte + TinyInt int8 + VarChar string + VarInt int64 + } + + err = session.ExecStmt(`CREATE TABLE IF NOT EXISTS examples.check_types ( + asci_i ascii, + big_int bigint, + blo_b blob, + boolea_n boolean, + dat_e date, + decima_l decimal, + doubl_e double, + duratio_n duration, + floa_t float, + ine_t inet, + in_t int, + small_int smallint, + tex_t text, + tim_e time, + timestam_p timestamp, + timeuui_d timeuuid, + tiny_int tinyint, + id uuid PRIMARY KEY, + var_char varchar, + var_int varint, + map_int_text map, + list_int list, + set_int set)`) + if err != nil { + t.Fatal("create table:", err) + } + + // generated with schemagen + checkTypesTable := table.New(table.Metadata{ + Name: "examples.check_types", + Columns: []string{ + "asci_i", + "big_int", + "blo_b", + "boolea_n", + "dat_e", + "decima_l", + "doubl_e", + "duratio_n", + "floa_t", + "id", + "in_t", + "ine_t", + "list_int", + "map_int_text", + "set_int", + "small_int", + "tex_t", + "tim_e", + "timestam_p", + "timeuui_d", + "tiny_int", + "var_char", + "var_int", + }, + PartKey: []string{"id"}, + SortKey: []string{}, + }) + + // Insert song using query builder. + insertCheckTypes := qb.Insert("examples.check_types"). + Columns("asci_i", "big_int", "blo_b", "boolea_n", "dat_e", "decima_l", "doubl_e", "duratio_n", "floa_t", "ine_t", "in_t", "small_int", "tex_t", "tim_e", "timestam_p", "timeuui_d", "tiny_int", "id", "var_char", "var_int", "map_int_text", "list_int", "set_int").Query(session) + + var byteId [16]byte + id := []byte("756716f7-2e54-4715-9f00-91dcbea6cf50") + copy(byteId[:], id) + + date := time.Date(2021, time.December, 11, 10, 23, 0, 0, time.UTC) + var double float64 = 1.2 + var float float32 = 1.3 + var integer int32 = 123 + listInt := []int32{1, 2, 3} + mapIntStr := map[int32]string{ + 1: "a", + 2: "b", + } + setInt := []int32{2, 4, 6} + var smallInt int16 = 12 + var tinyInt int8 = 14 + var varInt int64 = 20 + + insertCheckTypes.BindStruct(CheckTypesStruct{ + AsciI: "test qscci", + BigInt: 9223372036854775806, //MAXINT64 - 1, + BloB: []byte("this is blob test"), + BooleaN: false, + DatE: date, + DecimaL: *inf.NewDec(1, 1), + DoublE: double, + DuratioN: gocql.Duration{Months: 1, Days: 1, Nanoseconds: 86400}, + FloaT: float, + Id: byteId, + InT: integer, + IneT: "127.0.0.1", + ListInt: listInt, + MapIntText: mapIntStr, + SetInt: setInt, + SmallInt: smallInt, + TexT: "text example", + TimE: 86400000000, + TimestamP: date, + TimeuuiD: gocql.TimeUUID(), + TinyInt: tinyInt, + VarChar: "test varchar", + VarInt: varInt, + }) + if err := insertCheckTypes.ExecRelease(); err != nil { + t.Fatal("ExecRelease() failed:", err) + } + + // Query and displays data. + queryCheckTypes := checkTypesTable.SelectQuery(session) + + queryCheckTypes.BindStruct(&CheckTypesStruct{ + Id: byteId, + }) + + var items []*CheckTypesStruct + if err := queryCheckTypes.Select(&items); err != nil { + t.Fatal("Select() failed:", err) + } + + for _, i := range items { + t.Logf("%+v", *i) + } +} + // This example shows how to load a single value using "Get" function. // Get can also work with UDTs and types that implement gocql marshalling functions. func basicReadScyllaVersion(t *testing.T, session gocqlx.Session) { diff --git a/go.mod b/go.mod index c207cdb..f8ff801 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/scylladb/gocqlx/v2 go 1.17 require ( - github.com/gocql/gocql v0.0.0-20200131111108-92af2e088537 + github.com/gocql/gocql v0.0.0-20211015133455-b225f9b53fa1 github.com/google/go-cmp v0.5.4 github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef github.com/scylladb/go-reflectx v1.0.1 @@ -12,7 +12,7 @@ require ( ) require ( - github.com/golang/snappy v0.0.1 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 // indirect ) diff --git a/go.sum b/go.sum index e92d4c1..00894b5 100644 --- a/go.sum +++ b/go.sum @@ -5,9 +5,14 @@ github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dR github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gocql/gocql v0.0.0-20200131111108-92af2e088537 h1:NaMut1fdw76YYX/TPinSAbai4DShF5tPort3bHpET6g= github.com/gocql/gocql v0.0.0-20200131111108-92af2e088537/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY= +github.com/gocql/gocql v0.0.0-20211015133455-b225f9b53fa1 h1:px9qUCy/RNJNsfCam4m2IxWGxNuimkrioEF0vrrbPsg= +github.com/gocql/gocql v0.0.0-20211015133455-b225f9b53fa1/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.5.4 h1:L8R9j+yAqZuZjsqh/z+F1NCffTKKLShY6zXTItVIZ8M= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=