Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(database/gdb): add Raw support for Fields function of gdb.Model #3873

Merged
merged 8 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions contrib/drivers/mysql/mysql_z_unit_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4838,3 +4838,25 @@ func Test_OrderBy_Statement_Generated(t *testing.T) {
t.Assert(rawSql, expectSql)
})
}

func Test_Fields_Raw(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
table := createInitTable()
defer dropTable(table)
one, err := db.Model(table).Fields(gdb.Raw("1")).One()
t.AssertNil(err)
t.Assert(one["1"], 1)

one, err = db.Model(table).Fields(gdb.Raw("2")).One()
t.AssertNil(err)
t.Assert(one["2"], 2)

one, err = db.Model(table).Fields(gdb.Raw("2")).Where("id", 2).One()
t.AssertNil(err)
t.Assert(one["2"], 2)

one, err = db.Model(table).Fields(gdb.Raw("2")).Where("id", 10000000000).One()
t.AssertNil(err)
t.Assert(len(one), 0)
})
}
1 change: 1 addition & 0 deletions contrib/drivers/sqlite/sqlite_z_unit_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ func Test_Model_AllAndCount(t *testing.T) {
t.Assert(len(result), TableSize)
t.Assert(count, TableSize)
})

// AllAndCount with no data
gtest.C(t, func(t *gtest.T) {
result, count, err := db.Model(table).Where("id<0").AllAndCount(false)
Expand Down
6 changes: 3 additions & 3 deletions database/gdb/gdb_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ func doQuoteString(s, charLeft, charRight string) string {
return gstr.Join(array1, ",")
}

func getFieldsFromStructOrMap(structOrMap interface{}) (fields []string) {
fields = []string{}
func getFieldsFromStructOrMap(structOrMap any) (fields []any) {
fields = []any{}
if utils.IsStruct(structOrMap) {
structFields, _ := gstructs.Fields(gstructs.FieldsInput{
Pointer: structOrMap,
Expand All @@ -362,7 +362,7 @@ func getFieldsFromStructOrMap(structOrMap interface{}) (fields []string) {
}
}
} else {
fields = gutil.Keys(structOrMap)
fields = gconv.Interfaces(gutil.Keys(structOrMap))
}
return
}
Expand Down
13 changes: 8 additions & 5 deletions database/gdb/gdb_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ type Model struct {
linkType int // Mark for operation on master or slave.
tablesInit string // Table names when model initialization.
tables string // Operation table names, which can be more than one table names and aliases, like: "user", "user u", "user u, user_detail ud".
fields string // Operation fields, multiple fields joined using char ','.
fieldsEx []string // Excluded operation fields, it here uses slice instead of string type for quick filtering.
fields []any // Operation fields, multiple fields joined using char ','.
fieldsEx []any // Excluded operation fields, it here uses slice instead of string type for quick filtering.
withArray []interface{} // Arguments for With feature.
withAll bool // Enable model association operations on all objects that have "with" tag in the struct.
extraArgs []interface{} // Extra custom arguments for sql, which are prepended to the arguments before sql committed to underlying driver.
Expand Down Expand Up @@ -65,7 +65,7 @@ type ChunkHandler func(result Result, err error) bool
const (
linkTypeMaster = 1
linkTypeSlave = 2
defaultFields = "*"
defaultField = "*"
whereHolderOperatorWhere = 1
whereHolderOperatorAnd = 2
whereHolderOperatorOr = 3
Expand Down Expand Up @@ -132,7 +132,6 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model {
schema: c.schema,
tablesInit: tableStr,
tables: tableStr,
fields: defaultFields,
start: -1,
offset: -1,
filter: true,
Expand Down Expand Up @@ -281,8 +280,12 @@ func (m *Model) Clone() *Model {
newModel.whereBuilder = m.whereBuilder.Clone()
newModel.whereBuilder.model = newModel
// Shallow copy slice attributes.
if n := len(m.fields); n > 0 {
newModel.fields = make([]any, n)
copy(newModel.fields, m.fields)
}
if n := len(m.fieldsEx); n > 0 {
newModel.fieldsEx = make([]string, n)
newModel.fieldsEx = make([]any, n)
copy(newModel.fieldsEx, m.fieldsEx)
}
if n := len(m.extraArgs); n > 0 {
Expand Down
50 changes: 24 additions & 26 deletions database/gdb/gdb_model_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (m *Model) Fields(fieldNamesOrMapStruct ...interface{}) *Model {
return m
}
model := m.getModel()
return model.appendFieldsByStr(gstr.Join(fields, ","))
return model.appendToFields(fields...)
}

// FieldsPrefix performs as function Fields but add extra prefix for each field.
Expand All @@ -45,9 +45,11 @@ func (m *Model) FieldsPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...inte
if len(fields) == 0 {
return m
}
gstr.PrefixArray(fields, prefixOrAlias+".")
for i, field := range fields {
fields[i] = prefixOrAlias + "." + gconv.String(field)
}
model := m.getModel()
return model.appendFieldsByStr(gstr.Join(fields, ","))
return model.appendToFields(fields...)
}

// FieldsEx appends `fieldNamesOrMapStruct` to the excluded operation fields of the model,
Expand Down Expand Up @@ -84,7 +86,9 @@ func (m *Model) FieldsExPrefix(prefixOrAlias string, fieldNamesOrMapStruct ...in
m.getTableNameByPrefixOrAlias(prefixOrAlias),
fieldNamesOrMapStruct...,
)
gstr.PrefixArray(model.fieldsEx, prefixOrAlias+".")
for i, field := range model.fieldsEx {
model.fieldsEx[i] = prefixOrAlias + "." + gconv.String(field)
}
return model
}

Expand All @@ -95,7 +99,7 @@ func (m *Model) FieldCount(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`COUNT(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand All @@ -107,7 +111,7 @@ func (m *Model) FieldSum(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`SUM(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand All @@ -119,7 +123,7 @@ func (m *Model) FieldMin(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`MIN(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand All @@ -131,7 +135,7 @@ func (m *Model) FieldMax(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`MAX(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand All @@ -143,7 +147,7 @@ func (m *Model) FieldAvg(column string, as ...string) *Model {
asStr = fmt.Sprintf(` AS %s`, m.db.GetCore().QuoteWord(as[0]))
}
model := m.getModel()
return model.appendFieldsByStr(
return model.appendToFields(
fmt.Sprintf(`AVG(%s)%s`, m.QuoteWord(column), asStr),
)
}
Expand Down Expand Up @@ -218,7 +222,7 @@ func (m *Model) HasField(field string) (bool, error) {
}

// getFieldsFrom retrieves, filters and returns fields name from table `table`.
func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interface{}) []string {
func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...any) []any {
length := len(fieldNamesOrMapStruct)
if length == 0 {
return nil
Expand All @@ -227,21 +231,21 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac
// String slice.
case length >= 2:
return m.mappingAndFilterToTableFields(
table, gconv.Strings(fieldNamesOrMapStruct), true,
table, fieldNamesOrMapStruct, true,
)

// It needs type asserting.
case length == 1:
structOrMap := fieldNamesOrMapStruct[0]
switch r := structOrMap.(type) {
case string:
return m.mappingAndFilterToTableFields(table, []string{r}, false)
return m.mappingAndFilterToTableFields(table, []any{r}, false)

case []string:
return m.mappingAndFilterToTableFields(table, r, true)
return m.mappingAndFilterToTableFields(table, gconv.Interfaces(r), true)

case Raw, *Raw:
return []string{gconv.String(structOrMap)}
return []any{structOrMap}

default:
return m.mappingAndFilterToTableFields(table, getFieldsFromStructOrMap(structOrMap), true)
Expand All @@ -252,19 +256,13 @@ func (m *Model) filterFieldsFrom(table string, fieldNamesOrMapStruct ...interfac
}
}

func (m *Model) appendFieldsByStr(fields string) *Model {
if fields != "" {
model := m.getModel()
if model.fields == defaultFields {
model.fields = ""
}
if model.fields != "" {
model.fields += ","
}
model.fields += fields
return model
func (m *Model) appendToFields(fields ...any) *Model {
if len(fields) == 0 {
return m
}
return m
model := m.getModel()
model.fields = append(model.fields, fields...)
return model
}

func (m *Model) isFieldInFieldsEx(field string) bool {
Expand Down
67 changes: 45 additions & 22 deletions database/gdb/gdb_model_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (m *Model) AllAndCount(useFieldForCount bool) (result Result, totalCount in

// If useFieldForCount is false, set the fields to a constant value of 1 for counting
if !useFieldForCount {
countModel.fields = "1"
countModel.fields = []any{Raw("1")}
}

// Get the total count of records
Expand Down Expand Up @@ -178,7 +178,7 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) {
func (m *Model) doStruct(pointer interface{}, where ...interface{}) error {
model := m
// Auto selecting fields by struct attributes.
if len(model.fieldsEx) == 0 && (model.fields == "" || model.fields == "*") {
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
if v, ok := pointer.(reflect.Value); ok {
model = m.Fields(v.Interface())
} else {
Expand Down Expand Up @@ -214,7 +214,7 @@ func (m *Model) doStruct(pointer interface{}, where ...interface{}) error {
func (m *Model) doStructs(pointer interface{}, where ...interface{}) error {
model := m
// Auto selecting fields by struct attributes.
if len(model.fieldsEx) == 0 && (model.fields == "" || model.fields == "*") {
if len(model.fieldsEx) == 0 && len(model.fields) == 0 {
if v, ok := pointer.(reflect.Value); ok {
model = m.Fields(
reflect.New(
Expand Down Expand Up @@ -316,7 +316,7 @@ func (m *Model) ScanAndCount(pointer interface{}, totalCount *int, useFieldForCo
countModel := m.Clone()
// If useFieldForCount is false, set the fields to a constant value of 1 for counting
if !useFieldForCount {
countModel.fields = "1"
countModel.fields = []any{Raw("1")}
}

// Get the total count of records
Expand All @@ -343,7 +343,7 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string,
if err != nil {
return err
}
if m.fields != defaultFields || len(m.fieldsEx) != 0 {
if len(m.fields) > 0 || len(m.fieldsEx) != 0 {
// There are custom fields.
result, err = m.All()
} else {
Expand Down Expand Up @@ -604,7 +604,9 @@ func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{})
}

// doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql(ctx context.Context, queryType queryType, sql string, args ...interface{}) (result Result, err error) {
func (m *Model) doGetAllBySql(
ctx context.Context, queryType queryType, sql string, args ...interface{},
) (result Result, err error) {
if result, err = m.getSelectResultFromCache(ctx, sql, args...); err != nil || result != nil {
return
}
Expand Down Expand Up @@ -635,10 +637,10 @@ func (m *Model) getFormattedSqlAndArgs(
switch queryType {
case queryTypeCount:
queryFields := "COUNT(1)"
if m.fields != "" && m.fields != "*" {
if len(m.fields) > 0 {
// DO NOT quote the m.fields here, in case of fields like:
// DISTINCT t.user_id uid
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields)
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.getFieldsAsStr())
}
// Raw SQL Model.
if m.rawSql != "" {
Expand Down Expand Up @@ -691,29 +693,50 @@ func (m *Model) getAutoPrefix() string {
return autoPrefix
}

func (m *Model) getFieldsAsStr() string {
var (
fieldsStr string
core = m.db.GetCore()
)
for _, v := range m.fields {
field := gconv.String(v)
switch {
case gstr.ContainsAny(field, "()"):
case gstr.ContainsAny(field, ". "):
default:
switch v.(type) {
case Raw, *Raw:
default:
field = core.QuoteString(field)
}
}
if fieldsStr != "" {
fieldsStr += ","
}
fieldsStr += field
}
return fieldsStr
}

// getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will
// really be committed to underlying database driver.
func (m *Model) getFieldsFiltered() string {
if len(m.fieldsEx) == 0 {
// No filtering, containing special chars.
if gstr.ContainsAny(m.fields, "()") {
return m.fields
}
// No filtering.
if !gstr.ContainsAny(m.fields, ". ") {
return m.db.GetCore().QuoteString(m.fields)
}
return m.fields
if len(m.fieldsEx) == 0 && len(m.fields) == 0 {
return defaultField
}
if len(m.fieldsEx) == 0 && len(m.fields) > 0 {
return m.getFieldsAsStr()
}
var (
fieldsArray []string
fieldsExSet = gset.NewStrSetFrom(m.fieldsEx)
fieldsExSet = gset.NewStrSetFrom(gconv.Strings(m.fieldsEx))
)
if m.fields != "*" {
if len(m.fields) > 0 {
// Filter custom fields with fieldEx.
fieldsArray = make([]string, 0, 8)
for _, v := range gstr.SplitAndTrim(m.fields, ",") {
fieldsArray = append(fieldsArray, v[gstr.PosR(v, "-")+1:])
for _, v := range m.fields {
field := gconv.String(v)
fieldsArray = append(fieldsArray, field[gstr.PosR(field, "-")+1:])
}
} else {
if gstr.Contains(m.tables, " ") {
Expand Down
Loading
Loading