diff --git a/internal/log/log.go b/internal/log/log.go index d3e0e0fb1d..a887721c2d 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -36,15 +36,17 @@ var ( logger ddtrace.Logger = &defaultLogger{l: log.New(os.Stderr, "", log.LstdFlags)} ) -// UseLogger sets l as the active logger and returns the previously configured -// logger. -func UseLogger(l ddtrace.Logger) ddtrace.Logger { +// UseLogger sets l as the active logger and returns a function to restore the +// previous logger. The return value is mostly useful when testing. +func UseLogger(l ddtrace.Logger) (undo func()) { Flush() mu.Lock() defer mu.Unlock() old := logger logger = l - return old + return func() { + logger = old + } } // SetLevel sets the given lvl for logging. diff --git a/profiler/profiler_test.go b/profiler/profiler_test.go index 59005989f8..4e38cb54ed 100644 --- a/profiler/profiler_test.go +++ b/profiler/profiler_test.go @@ -62,8 +62,7 @@ func TestStart(t *testing.T) { t.Run("options/GoodAPIKey/Agent", func(t *testing.T) { rl := &log.RecordLogger{} - old := log.UseLogger(rl) - defer log.UseLogger(old) + defer log.UseLogger(rl)() err := Start(WithAPIKey("12345678901234567890123456789012")) defer Stop() @@ -75,8 +74,7 @@ func TestStart(t *testing.T) { t.Run("options/GoodAPIKey/Agentless", func(t *testing.T) { rl := &log.RecordLogger{} - old := log.UseLogger(rl) - defer log.UseLogger(old) + defer log.UseLogger(rl)() err := Start( WithAPIKey("12345678901234567890123456789012"),