diff --git a/dependency.go b/dependency.go index d24b465..d26845e 100644 --- a/dependency.go +++ b/dependency.go @@ -201,6 +201,13 @@ func newDefaultDriver(args DriverArgs) (Driver, error) { if err := args.Populator.Populate(&injected); err != nil { return nil, fmt.Errorf("missing dependency for the default queue driver: %w", err) } + driver, err := driverFromDI(args.Populator) + if err != nil { + return nil, fmt.Errorf("error fetching default driver from DI: %w", err) + } + if driver != nil { + return driver, nil + } var redisName string if err := injected.ConfigUnmarshaler.Unmarshal(fmt.Sprintf("queue.%s.redisName", injected.AppName), &redisName); err != nil { return nil, fmt.Errorf("bad configuration: %w", err) @@ -257,3 +264,15 @@ func provideConfig() configOut { }} return configOut{Config: configs} } + +func driverFromDI(populator contract.DIPopulator) (Driver, error) { + var injected struct { + di.In + Driver `optional:"true"` + } + err := populator.Populate(&injected) + if err != nil { + return nil, err + } + return injected.Driver, nil +} diff --git a/dependency_test.go b/dependency_test.go index 977d005..18fdaca 100644 --- a/dependency_test.go +++ b/dependency_test.go @@ -133,3 +133,23 @@ func TestProvideConfigs(t *testing.T) { c := provideConfig() assert.NotEmpty(t, c.Config) } + +type driverPopulator struct{} + +func (d driverPopulator) Populate(target interface{}) error { + graph := di.NewGraph() + graph.Provide(func() Driver { + return mockDriver{} + }) + di.IntoPopulator(graph).Populate(target) + return nil +} + +func TestDriverFromDI(t *testing.T) { + driver, err := newDefaultDriver(DriverArgs{ + Name: "", + Populator: driverPopulator{}, + }) + assert.NoError(t, err) + assert.IsType(t, mockDriver{}, driver) +}