Skip to content
This repository has been archived by the owner on Apr 2, 2024. It is now read-only.

Commit

Permalink
Fix for detecting names using a slice of models
Browse files Browse the repository at this point in the history
  • Loading branch information
mrz1836 committed Apr 4, 2022
1 parent fe4481e commit 8225116
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
14 changes: 7 additions & 7 deletions datastore/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,14 @@ func (c *Client) incrementWithMongo(
return
}

// getWithMongo will get a given struct from MongoDB
// getWithMongo will get given struct(s) from MongoDB
func (c *Client) getWithMongo(
ctx context.Context,
model interface{},
models interface{},
conditions map[string]interface{},
) error {
queryConditions := getMongoQueryConditions(model, conditions)
collectionName := utils.GetModelTableName(model)
queryConditions := getMongoQueryConditions(models, conditions)
collectionName := utils.GetModelTableName(models)
if collectionName == nil {
return ErrUnknownCollection
}
Expand All @@ -132,7 +132,7 @@ func (c *Client) getWithMongo(
setPrefix(c.options.mongoDBConfig.TablePrefix, *collectionName),
)

if utils.IsModelSlice(model) {
if utils.IsModelSlice(models) {
c.DebugLog(fmt.Sprintf(logLine, "findMany", *collectionName, queryConditions))

cursor, err := collection.Find(ctx, queryConditions)
Expand All @@ -145,7 +145,7 @@ func (c *Client) getWithMongo(
return cursor.Err()
}

if err = cursor.All(ctx, model); err != nil {
if err = cursor.All(ctx, models); err != nil {
return err
}
} else {
Expand All @@ -160,7 +160,7 @@ func (c *Client) getWithMongo(
return result.Err()
}

if err := result.Decode(model); err != nil {
if err := result.Decode(models); err != nil {
c.DebugLog(fmt.Sprintf(logLine, "result err", *collectionName, err))
return err
}
Expand Down
6 changes: 4 additions & 2 deletions utils/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ func GetModelName(model interface{}) *string {
}

// Model is a pointer
if reflect.ValueOf(model).Type().Kind() == reflect.Ptr {
k := GetModelType(model).Kind()
if reflect.ValueOf(model).Type().Kind() == reflect.Ptr && k != reflect.Struct {
if m, ok := model.(checkForMethod); ok {
name := m.GetModelName()
return &name
Expand All @@ -58,7 +59,8 @@ func GetModelTableName(model interface{}) *string {
}

// Model is a pointer
if reflect.ValueOf(model).Type().Kind() == reflect.Ptr {
k := GetModelType(model).Kind()
if reflect.ValueOf(model).Type().Kind() == reflect.Ptr && k != reflect.Struct {
if m, ok := model.(checkForMethod); ok {
name := m.GetModelTableName()
return &name
Expand Down
12 changes: 12 additions & 0 deletions utils/models_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ func TestGetModelName(t *testing.T) {
assert.Equal(t, testModelName, *name)
})

t.Run("models are set - value", func(t *testing.T) {
tm := &[]testModel{{Field: testModelName}}
name := GetModelName(tm)
assert.Equal(t, testModelName, *name)
})

t.Run("model does not have method - pointer", func(t *testing.T) {
tm := &badModel{}
name := GetModelName(tm)
Expand Down Expand Up @@ -212,6 +218,12 @@ func TestGetModelTableName(t *testing.T) {
assert.Equal(t, testTableName, *name)
})

t.Run("models are set - value", func(t *testing.T) {
tm := &[]testModel{{Field: testModelName}}
name := GetModelTableName(tm)
assert.Equal(t, testModelName, *name)
})

t.Run("model does not have method - pointer", func(t *testing.T) {
tm := &badModel{}
name := GetModelTableName(tm)
Expand Down

0 comments on commit 8225116

Please sign in to comment.