diff --git a/generate.go b/generate.go index f290511..4d90290 100644 --- a/generate.go +++ b/generate.go @@ -9,7 +9,6 @@ import ( "io" "os" "reflect" - "strconv" "strings" ) @@ -284,7 +283,7 @@ func getScaleType(parentType reflect.Type, field reflect.StructField) (scaleType case reflect.Bool: return scaleType{Name: "Bool"}, nil case reflect.String: - maxElements, err := getMaxElements(field.Tag) + maxElements, err := maxScaleElements(field.Tag) if err != nil { return scaleType{}, fmt.Errorf("scale tag has incorrect max value: %w", err) } @@ -318,7 +317,7 @@ func getScaleType(parentType reflect.Type, field reflect.StructField) (scaleType // []string return scaleType{}, errors.New("string slices are not supported") } - maxElements, err := getMaxElements(field.Tag) + maxElements, err := maxScaleElements(field.Tag) if err != nil { return scaleType{}, fmt.Errorf("scale tag has incorrect max value: %w", err) } @@ -338,41 +337,6 @@ func getScaleType(parentType reflect.Type, field reflect.StructField) (scaleType return scaleType{}, fmt.Errorf("type %v is not supported", field.Type.Kind()) } -func getMaxElements(tag reflect.StructTag) (uint32, error) { - scaleTagValue, exists := tag.Lookup("scale") - if !exists || scaleTagValue == "" { - return 0, nil - } - if scaleTagValue == "" { - return 0, errors.New("scale tag is not defined") - } - pairs := strings.Split(scaleTagValue, ",") - if len(pairs) == 0 { - return 0, errors.New("no max value found in scale tag") - } - var maxElementsStr string - for _, pair := range pairs { - pair = strings.TrimSpace(pair) - data := strings.Split(pair, "=") - if len(data) < 2 { - continue - } - if data[0] != "max" { - continue - } - maxElementsStr = strings.TrimSpace(data[1]) - break - } - if maxElementsStr == "" { - return 0, errors.New("no max value found in scale tag") - } - maxElements, err := strconv.Atoi(maxElementsStr) - if err != nil { - return 0, fmt.Errorf("parsing max value: %w", err) - } - return uint32(maxElements), nil -} - func getTemplate(stype scaleType) temp { switch { case stype.Name == "StructArray": diff --git a/tag.go b/tag.go new file mode 100644 index 0000000..f622893 --- /dev/null +++ b/tag.go @@ -0,0 +1,87 @@ +package scale + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" +) + +// MaxScaleElementsForField returns the max number of elements for the specified field in +// the struct passed as the v argument based on the 'scale' tag. It returns an error if v +// is not a structure, if max is not specified for the field, the field doesn't exist or +// there's a problem parsing the tag. +func MaxScaleElementsForField(v any, fieldName string) (uint32, error) { + typ := reflect.TypeOf(v) + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + if typ.Kind() != reflect.Struct { + return 0, errors.New("bad value type") + } + f, found := typ.FieldByName(fieldName) + if !found { + return 0, fmt.Errorf("unknown field %q in %T", fieldName, v) + } + maxElements, err := maxScaleElements(f.Tag) + if err != nil { + return 0, fmt.Errorf("error getting field tag %q in %T: %w", fieldName, v, err) + } + if maxElements == 0 { + return 0, fmt.Errorf("no max in the scale tag for field %q in %T", fieldName, v) + } + return maxElements, nil +} + +// MaxScaleElements is a generic version of GetMaxElementsFromValue that uses the specified +// type instead of a struct value. +func MaxScaleElements[T any](fieldName string) (uint32, error) { + var v T + return MaxScaleElementsForField(v, fieldName) +} + +// MustGetMaxElements is the same as GetMaxElements, but returns just the max tag value +// and panics in case of an error. +func MustGetMaxElements[T any](fieldName string) uint32 { + maxElements, err := MaxScaleElements[T](fieldName) + if err != nil { + panic(err) + } + return maxElements +} + +func maxScaleElements(tag reflect.StructTag) (uint32, error) { + scaleTagValue, exists := tag.Lookup("scale") + if !exists || scaleTagValue == "" { + return 0, nil + } + if scaleTagValue == "" { + return 0, errors.New("scale tag is not defined") + } + pairs := strings.Split(scaleTagValue, ",") + if len(pairs) == 0 { + return 0, errors.New("no max value found in scale tag") + } + var maxElementsStr string + for _, pair := range pairs { + pair = strings.TrimSpace(pair) + data := strings.Split(pair, "=") + if len(data) < 2 { + continue + } + if data[0] != "max" { + continue + } + maxElementsStr = strings.TrimSpace(data[1]) + break + } + if maxElementsStr == "" { + return 0, errors.New("no max value found in scale tag") + } + maxElements, err := strconv.Atoi(maxElementsStr) + if err != nil { + return 0, fmt.Errorf("parsing max value: %w", err) + } + return uint32(maxElements), nil +} diff --git a/tag_test.go b/tag_test.go new file mode 100644 index 0000000..967bdc9 --- /dev/null +++ b/tag_test.go @@ -0,0 +1,43 @@ +package scale + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +type Foo struct { + Name string `scale:"max=64"` + Items []int `scale:"max=4242"` + Bad []int `scale:"max=4242x"` + Bad1 []int `scale:"abc=4242"` + NoTag []int +} + +func TestGetMaxElements(t *testing.T) { + n, err := MaxScaleElements[Foo]("Name") + require.NoError(t, err) + require.Equal(t, uint32(64), n) + n, err = MaxScaleElements[Foo]("Items") + require.NoError(t, err) + require.Equal(t, uint32(4242), n) + n, err = MaxScaleElements[*Foo]("Name") + require.NoError(t, err) + require.Equal(t, uint32(64), n) + n, err = MaxScaleElements[*Foo]("Items") + require.NoError(t, err) + require.Equal(t, uint32(4242), n) + require.Equal(t, uint32(64), MustGetMaxElements[Foo]("Name")) + require.Equal(t, uint32(4242), MustGetMaxElements[Foo]("Items")) + + _, err = MaxScaleElements[Foo]("NoSuchField") + require.Error(t, err) + _, err = MaxScaleElements[int]("Name") + require.Error(t, err) + _, err = MaxScaleElements[int]("Bad") + require.Error(t, err) + _, err = MaxScaleElements[int]("Bad1") + require.Error(t, err) + _, err = MaxScaleElements[int]("NoTag") + require.Error(t, err) +}