diff --git a/callbacks.go b/callbacks.go index 139a83bf..f488395e 100644 --- a/callbacks.go +++ b/callbacks.go @@ -4,6 +4,8 @@ import ( "fmt" "reflect" "strings" + + "github.com/pkg/errors" ) type bindings map[reflect.Type]func() (reflect.Value, error) @@ -24,6 +26,30 @@ func (b bindings) add(values ...interface{}) bindings { return b } +func (b bindings) addTo(impl, iface interface{}) { + valueOf := reflect.ValueOf(impl) + b[reflect.TypeOf(iface).Elem()] = func() (reflect.Value, error) { return valueOf, nil } +} + +func (b bindings) addProvider(provider interface{}) error { + pv := reflect.ValueOf(provider) + t := pv.Type() + if t.Kind() != reflect.Func || t.NumIn() != 0 || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() { + return errors.Errorf("%T must be a function with the signature func()(T, error)", provider) + } + rt := pv.Type().Out(0) + b[rt] = func() (reflect.Value, error) { + out := pv.Call(nil) + errv := out[1] + var err error + if !errv.IsNil() { + err = errv.Interface().(error) // nolint + } + return out[0], err + } + return nil +} + // Clone and add values. func (b bindings) clone() bindings { out := make(bindings, len(b)) diff --git a/context.go b/context.go index 83368d10..21c4bf6e 100644 --- a/context.go +++ b/context.go @@ -113,8 +113,15 @@ func (c *Context) Bind(args ...interface{}) { // // BindTo(impl, (*MyInterface)(nil)) func (c *Context) BindTo(impl, iface interface{}) { - valueOf := reflect.ValueOf(impl) - c.bindings[reflect.TypeOf(iface).Elem()] = func() (reflect.Value, error) { return valueOf, nil } + c.bindings.addTo(impl, iface) +} + +// BindToProvider allows binding of provider functions. +// +// This is useful when the Run() function of different commands require different values that may +// not all be initialisable from the main() function. +func (c *Context) BindToProvider(provider interface{}) error { + return c.bindings.addProvider(provider) } // Value returns the value for a particular path element. diff --git a/options.go b/options.go index 6036d91f..2b69a17f 100644 --- a/options.go +++ b/options.go @@ -180,8 +180,7 @@ func Bind(args ...interface{}) Option { // BindTo(impl, (*iface)(nil)) func BindTo(impl, iface interface{}) Option { return OptionFunc(func(k *Kong) error { - valueOf := reflect.ValueOf(impl) - k.bindings[reflect.TypeOf(iface).Elem()] = func() (reflect.Value, error) { return valueOf, nil } + k.bindings.addTo(impl, iface) return nil }) } @@ -192,22 +191,7 @@ func BindTo(impl, iface interface{}) Option { // not all be initialisable from the main() function. func BindToProvider(provider interface{}) Option { return OptionFunc(func(k *Kong) error { - pv := reflect.ValueOf(provider) - t := pv.Type() - if t.Kind() != reflect.Func || t.NumIn() != 0 || t.NumOut() != 2 || t.Out(1) != reflect.TypeOf((*error)(nil)).Elem() { - return errors.Errorf("%T must be a function with the signature func()(T, error)", provider) - } - rt := pv.Type().Out(0) - k.bindings[rt] = func() (reflect.Value, error) { - out := pv.Call(nil) - errv := out[1] - var err error - if !errv.IsNil() { - err = errv.Interface().(error) // nolint - } - return out[0], err - } - return nil + return k.bindings.addProvider(provider) }) }