Skip to content

Commit

Permalink
feat(database/gdb): add Raw support for Fields function of `gdb.M…
Browse files Browse the repository at this point in the history
…odel` (#3873)
  • Loading branch information
gqcn authored Oct 21, 2024
1 parent b1d875a commit 7dd38a1
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 71 deletions.
22 changes: 22 additions & 0 deletions contrib/drivers/mysql/mysql_z_unit_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4838,3 +4838,25 @@ func Test_OrderBy_Statement_Generated(t *testing.T) {
t.Assert(rawSql, expectSql)
})
}

func Test_Fields_Raw(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
table := createInitTable()
defer dropTable(table)
one, err := db.Model(table).Fields(gdb.Raw("1")).One()
t.AssertNil(err)
t.Assert(one["1"], 1)

one, err = db.Model(table).Fields(gdb.Raw("2")).One()
t.AssertNil(err)
t.Assert(one["2"], 2)

one, err = db.Model(table).Fields(gdb.Raw("2")).Where("id", 2).One()
t.AssertNil(err)
t.Assert(one["2"], 2)

one, err = db.Model(table).Fields(gdb.Raw("2")).Where("id", 10000000000).One()
t.AssertNil(err)
t.Assert(len(one), 0)
})
}
1 change: 1 addition & 0 deletions contrib/drivers/sqlite/sqlite_z_unit_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ func Test_Model_AllAndCount(t *testing.T) {
t.Assert(len(result), TableSize)
t.Assert(count, TableSize)
})

// AllAndCount with no data
gtest.C(t, func(t *gtest.T) {
result, count, err := db.Model(table).Where("id<0").AllAndCount(false)
Expand Down
6 changes: 3 additions & 3 deletions database/gdb/gdb_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ func doQuoteString(s, charLeft, charRight string) string {
return gstr.Join(array1, ",")
}

func getFieldsFromStructOrMap(structOrMap interface{}) (fields []string) {
fields = []string{}
func getFieldsFromStructOrMap(structOrMap any) (fields []any) {
fields = []any{}
if utils.IsStruct(structOrMap) {
structFields, _ := gstructs.Fields(gstructs.FieldsInput{
Pointer: structOrMap,
Expand All @@ -362,7 +362,7 @@ func getFieldsFromStructOrMap(structOrMap interface{}) (fields []string) {
}
}
} else {
fields = gutil.Keys(structOrMap)
fields = gconv.Interfaces(gutil.Keys(structOrMap))
}
return
}
Expand Down
13 changes: 8 additions & 5 deletions database/gdb/gdb_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ type Model struct {
linkType int // Mark for operation on master or slave.
tablesInit string // Table names when model initialization.
tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud".
fields string // Operation fields, multiple fields joined using char ','.
fieldsEx []string // Excluded operation fields, it here uses slice instead of string type for quick filtering.
fields []any // Operation fields, multiple fields joined using char ','.
fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering.
withArray []interface{} // Arguments for With feature.
withAll bool // Enable model association operations on all objects that have "with" tag in the struct.
extraArgs []interface{} // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver.
Expand Down Expand Up @@ -65,7 +65,7 @@ type ChunkHandler func(result Result, err error) bool
const (
linkTypeMaster = 1
linkTypeSlave = 2
defaultFields = "*"
defaultField = "*"
whereHolderOperatorWhere = 1
whereHolderOperatorAnd = 2
whereHolderOperatorOr = 3
Expand Down Expand Up @@ -132,7 +132,6 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model {
schema: c.schema,
tablesInit: tableStr,
tables: tableStr,
fields: defaultFields,
start: -1,
offset: -1,
filter: true,
Expand Down Expand Up @@ -281,8 +280,12 @@ func (m *Model) Clone() *Model {
newModel.whereBuilder = m.whereBuilder.Clone()
newModel.whereBuilder.model = newModel
// Shallow copy slice attributes.
if n := len(m.fields); n > 0 {
newModel.fields = make([]any, n)
copy(newModel.fields, m.fields)
}
if n := len(m.fieldsEx); n > 0 {
newModel.fieldsEx = make([]string, n)
newModel.fieldsEx = make([]any, n)
copy(newModel.fieldsEx, m.fieldsEx)
}
if n := len(m.extraArgs); n > 0 {
Expand Down
50 changes: 24 additions & 26 deletions database/gdb/gdb_model_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (m *Model) Fields(fieldNamesOrMapStruct ...interface{}) *Model {
return m
}
model := m.getModel()
return model.appendFieldsByStr(gstr.Join(fields, ","))
return model.appendToFields(fields...)
}

// FieldsPrefix performs as function Fields but add extra prefix for each field.
Expand All @@ -45,9 +45,11 @@ func (m *Model) FieldsPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...inte
if len(fields) == 0 {
return m
}
gstr.PrefixArray(fields, prefixOrAlias+".")
for i, field := range fields {
fields[i] = prefixOrAlias + "." + gconv.String(field)
}
model := m.getModel()
return model.appendFieldsByStr(gstr.Join(fields, ","))
return model.appendToFields(fields...)
}

// FieldsEx appends `fieldNamesOrMapStruct` to the excluded operation fields of the model,
Expand Down Expand Up @@ -84,7 +86,9 @@ func (m *Model) FieldsExPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...in
m.getTableNameByPrefixOrAlias(prefixOrAlias),
fieldNamesOrMapStruct...,
)
gstr.PrefixArray(model.fieldsEx, prefixOrAlias+".")
for i, field := range model.fieldsEx {
model.fieldsEx[i] = prefixOrAlias + "." + gconv.String(field)
}
return model
}

Expand All @@ -95,7 +99,7 @@ func (m *Model) FieldCount(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`COUNT(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand All @@ -107,7 +111,7 @@ func (m *Model) FieldSum(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`SUM(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand All @@ -119,7 +123,7 @@ func (m *Model) FieldMin(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`MIN(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand All @@ -131,7 +135,7 @@ func (m *Model) FieldMax(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`MAX(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand All @@ -143,7 +147,7 @@ func (m *Model) FieldAvg(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`AVG(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand Down Expand Up @@ -218,7 +222,7 @@ func (m *Model) HasField(field string) (bool, error) {
}

// getFieldsFrom retrieves, filters and returns fields name from table `table`.
func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interface{}) []string {
func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...any) []any {
length := len(fieldNamesOrMapStruct)
if length == 0 {
return nil
Expand All @@ -227,21 +231,21 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac
// String slice.
case length >= 2:
return m.mappingAndFilterToTableFields(
table, gconv.Strings(fieldNamesOrMapStruct), true,
table, fieldNamesOrMapStruct, true,
)

// It needs type asserting.
case length == 1:
structOrMap := fieldNamesOrMapStruct[0]
switch r := structOrMap.(type) {
case string:
return m.mappingAndFilterToTableFields(table, []string{r}, false)
return m.mappingAndFilterToTableFields(table, []any{r}, false)

case []string:
return m.mappingAndFilterToTableFields(table, r, true)
return m.mappingAndFilterToTableFields(table, gconv.Interfaces(r), true)

case Raw, *Raw:
return []string{gconv.String(structOrMap)}
return []any{structOrMap}

default:
return m.mappingAndFilterToTableFields(table, getFieldsFromStructOrMap(structOrMap), true)
Expand All @@ -252,19 +256,13 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac
}
}

func (m *Model) appendFieldsByStr(fields string) *Model {
if fields != "" {
model := m.getModel()
if model.fields == defaultFields {
model.fields = ""
}
if model.fields != "" {
model.fields += ","
}
model.fields += fields
return model
func (m *Model) appendToFields(fields ...any) *Model {
if len(fields) == 0 {
return m
}
return m
model := m.getModel()
model.fields = append(model.fields, fields...)
return model
}

func (m *Model) isFieldInFieldsEx(field string) bool {
Expand Down
67 changes: 45 additions & 22 deletions database/gdb/gdb_model_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in

// If useFieldForCount is false, set the fields to a constant value of 1 for counting
if !useFieldForCount {
countModel.fields = "1"
countModel.fields = []any{Raw("1")}
}

// Get the total count of records
Expand Down Expand Up @@ -178,7 +178,7 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
func (m *Model) doStruct(pointer interface{}, where ...interface{}) error {
model := m
// Auto selecting fields by struct attributes.
if len(model.fieldsEx) == 0 && (model.fields == "" || model.fields == "*") {
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
if v, ok := pointer.(reflect.Value); ok {
model = m.Fields(v.Interface())
} else {
Expand Down Expand Up @@ -214,7 +214,7 @@ func (m *Model) doStruct(pointer interface{}, where ...interface{}) error {
func (m *Model) doStructs(pointer interface{}, where ...interface{}) error {
model := m
// Auto selecting fields by struct attributes.
if len(model.fieldsEx) == 0 && (model.fields == "" || model.fields == "*") {
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
if v, ok := pointer.(reflect.Value); ok {
model = m.Fields(
reflect.New(
Expand Down Expand Up @@ -316,7 +316,7 @@ func (m *Model) ScanAndCount(pointer interface{}, totalCount *int, useFieldForCo
countModel := m.Clone()
// If useFieldForCount is false, set the fields to a constant value of 1 for counting
if !useFieldForCount {
countModel.fields = "1"
countModel.fields = []any{Raw("1")}
}

// Get the total count of records
Expand All @@ -343,7 +343,7 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string,
if err != nil {
return err
}
if m.fields != defaultFields || len(m.fieldsEx) != 0 {
if len(m.fields) > 0 || len(m.fieldsEx) != 0 {
// There are custom fields.
result, err = m.All()
} else {
Expand Down Expand Up @@ -604,7 +604,9 @@ func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{})
}

// doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql(ctx context.Context, queryType queryType, sql string, args ...interface{}) (result Result, err error) {
func (m *Model) doGetAllBySql(
ctx context.Context, queryType queryType, sql string, args ...interface{},
) (result Result, err error) {
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
return
}
Expand Down Expand Up @@ -635,10 +637,10 @@ func (m *Model) getFormattedSqlAndArgs(
switch queryType {
case queryTypeCount:
queryFields := "COUNT(1)"
if m.fields != "" && m.fields != "*" {
if len(m.fields) > 0 {
// DO NOT quote the m.fields here, in case of fields like:
// DISTINCT t.user_id uid
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields)
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.getFieldsAsStr())
}
// Raw SQL Model.
if m.rawSql != "" {
Expand Down Expand Up @@ -691,29 +693,50 @@ func (m *Model) getAutoPrefix() string {
return autoPrefix
}

func (m *Model) getFieldsAsStr() string {
var (
fieldsStr string
core = m.db.GetCore()
)
for _, v := range m.fields {
field := gconv.String(v)
switch {
case gstr.ContainsAny(field, "()"):
case gstr.ContainsAny(field, ". "):
default:
switch v.(type) {
case Raw, *Raw:
default:
field = core.QuoteString(field)
}
}
if fieldsStr != "" {
fieldsStr += ","
}
fieldsStr += field
}
return fieldsStr
}

// getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will
// really be committed to underlying database driver.
func (m *Model) getFieldsFiltered() string {
if len(m.fieldsEx) == 0 {
// No filtering, containing special chars.
if gstr.ContainsAny(m.fields, "()") {
return m.fields
}
// No filtering.
if !gstr.ContainsAny(m.fields, ". ") {
return m.db.GetCore().QuoteString(m.fields)
}
return m.fields
if len(m.fieldsEx) == 0 && len(m.fields) == 0 {
return defaultField
}
if len(m.fieldsEx) == 0 && len(m.fields) > 0 {
return m.getFieldsAsStr()
}
var (
fieldsArray []string
fieldsExSet = gset.NewStrSetFrom(m.fieldsEx)
fieldsExSet = gset.NewStrSetFrom(gconv.Strings(m.fieldsEx))
)
if m.fields != "*" {
if len(m.fields) > 0 {
// Filter custom fields with fieldEx.
fieldsArray = make([]string, 0, 8)
for _, v := range gstr.SplitAndTrim(m.fields, ",") {
fieldsArray = append(fieldsArray, v[gstr.PosR(v, "-")+1:])
for _, v := range m.fields {
field := gconv.String(v)
fieldsArray = append(fieldsArray, field[gstr.PosR(field, "-")+1:])
}
} else {
if gstr.Contains(m.tables, " ") {
Expand Down
Loading

0 comments on commit 7dd38a1

Please sign in to comment.