Skip to content

Commit

Permalink
feature: bring custom type and id column name to polymorphism (#6716)
Browse files Browse the repository at this point in the history
* feature: bring custom type and id column name to polymorphism

* relationship: better returns for hasPolymorphicRelation

* fix: tests
  • Loading branch information
alexisvisco committed Dec 15, 2023
1 parent b9ebdb1 commit a2cac75
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 53 deletions.
67 changes: 53 additions & 14 deletions schema/relationship.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return nil
}

if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
schema.buildPolymorphicRelation(relation, field, polymorphic)
if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
Expand All @@ -89,7 +89,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
}
}

Expand Down Expand Up @@ -124,6 +125,20 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation
}

// hasPolymorphicRelation check if has polymorphic relation
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}

_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]

return hasType && hasId
}

func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
Expand Down Expand Up @@ -169,23 +184,41 @@ func (schema *Schema) setRelation(relation *Relationship) {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]

relation.Polymorphic = &Polymorphic{
Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
Value: schema.Table,
}

var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)

if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}

if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}

relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]

if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}

if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}

if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}

if schema.err == nil {
Expand All @@ -197,12 +230,14 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
schema, field.Name)
}
}

if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
return
}

Expand Down Expand Up @@ -317,7 +352,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Tag: `gorm:"-"`,
})

if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
Expand Down Expand Up @@ -436,7 +472,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
}
}

Expand Down Expand Up @@ -492,7 +529,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu

lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}

for _, name := range lookUpNames {
Expand Down
187 changes: 187 additions & 0 deletions schema/relationship_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,193 @@ func TestEmbeddedHas(t *testing.T) {
})
}

func TestPolymorphic(t *testing.T) {
t.Run("has one", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}

type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;"`
}

s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}

checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})

t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}

type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
}

s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}

checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})

t.Run("has one with only polymorphic type", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
Type string
}

type Cat struct {
ID int
Name string
Toy Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
}

s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}

checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toy": {
Name: "Toy",
Type: schema.HasOne,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})

t.Run("has many", func(t *testing.T) {
type Toy struct {
ID int
Name string
OwnerID int
OwnerType string
}

type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphic:Owner;"`
}

s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}

checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
References: []Reference{
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})

t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
type Toy struct {
ID int
Name string
RefId int
Type string
}

type Cat struct {
ID int
Name string
Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
}

s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("Failed to parse schema, got error %v", err)
}

checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
"Cat": {
Relations: map[string]Relation{
"Toys": {
Name: "Toys",
Type: schema.HasMany,
Schema: "User",
FieldSchema: "Toy",
Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
References: []Reference{
{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
},
},
},
},
})
})
}

func TestEmbeddedBelongsTo(t *testing.T) {
type Country struct {
ID int `gorm:"primaryKey"`
Expand Down
11 changes: 10 additions & 1 deletion tests/associations_has_many_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,30 @@ func TestPolymorphicHasManyAssociation(t *testing.T) {
func TestPolymorphicHasManyAssociationForSlice(t *testing.T) {
users := []User{
*GetUser("slice-hasmany-1", Config{Toys: 2}),
*GetUser("slice-hasmany-2", Config{Toys: 0}),
*GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}),
*GetUser("slice-hasmany-3", Config{Toys: 4}),
}

DB.Create(&users)

// Count
AssertAssociationCount(t, users, "Toys", 6, "")
AssertAssociationCount(t, users, "Tools", 2, "")

// Find
var toys []Toy
if DB.Model(&users).Association("Toys").Find(&toys); len(toys) != 6 {
t.Errorf("toys count should be %v, but got %v", 6, len(toys))
}

// Find Tools (polymorphic with custom type and id)
var tools []Tools
DB.Model(&users).Association("Tools").Find(&tools)

if len(tools) != 2 {
t.Errorf("tools count should be %v, but got %v", 2, len(tools))
}

// Append
DB.Model(&users).Association("Toys").Append(
&Toy{Name: "toy-slice-append-1"},
Expand Down
Loading

0 comments on commit a2cac75

Please sign in to comment.