Skip to content

Commit

Permalink
Merge pull request #2203 from coryb/issue-2112
Browse files Browse the repository at this point in the history
[#2112] progress.Controller should own the progress.Writer to prevent leaks
  • Loading branch information
coryb authored Jun 28, 2021
2 parents 8d33bbd + b1d441b commit f4fcba5
Show file tree
Hide file tree
Showing 18 changed files with 82 additions and 63 deletions.
2 changes: 1 addition & 1 deletion cache/remotecache/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
type ResolveCacheExporterFunc func(ctx context.Context, g session.Group, attrs map[string]string) (Exporter, error)

func oneOffProgress(ctx context.Context, id string) func(err error) error {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand Down
2 changes: 1 addition & 1 deletion exporter/containerimage/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ func getRefMetadata(ref cache.ImmutableRef, limit int) []refMetadata {
}

func oneOffProgress(ctx context.Context, id string) func(err error) error {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand Down
2 changes: 1 addition & 1 deletion exporter/local/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (e *localExporterInstance) Export(ctx context.Context, inp exporter.Source,

func newProgressHandler(ctx context.Context, id string) func(int, bool) {
limiter := rate.NewLimiter(rate.Every(100*time.Millisecond), 1)
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand Down
2 changes: 1 addition & 1 deletion exporter/oci/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (e *imageExporterInstance) Export(ctx context.Context, src exporter.Source,
}

func oneOffProgress(ctx context.Context, id string) func(err error) error {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand Down
2 changes: 1 addition & 1 deletion exporter/tar/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (e *localExporterInstance) Export(ctx context.Context, inp exporter.Source,
}

func oneOffProgress(ctx context.Context, id string) func(err error) error {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand Down
6 changes: 3 additions & 3 deletions solver/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ func (jl *Solver) NewJob(id string) (*Job, error) {
}

pr, ctx, progressCloser := progress.NewContext(context.Background())
pw, _, _ := progress.FromContext(ctx) // TODO: expose progress.Pipe()
pw, _, _ := progress.NewFromContext(ctx) // TODO: expose progress.Pipe()

_, span := trace.NewNoopTracerProvider().Tracer("").Start(ctx, "")
j := &Job{
Expand Down Expand Up @@ -881,7 +881,7 @@ func (v *vertexWithCacheOptions) Inputs() []Edge {
}

func notifyStarted(ctx context.Context, v *client.Vertex, cached bool) {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
defer pw.Close()
now := time.Now()
v.Started = &now
Expand All @@ -891,7 +891,7 @@ func notifyStarted(ctx context.Context, v *client.Vertex, cached bool) {
}

func notifyCompleted(ctx context.Context, v *client.Vertex, err error, cached bool) {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
defer pw.Close()
now := time.Now()
if v.Started == nil {
Expand Down
8 changes: 4 additions & 4 deletions solver/llbsolver/solver.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ func allWorkers(wc *worker.Controller) func(func(w worker.Worker) error) error {
}

func oneOffProgress(ctx context.Context, id string) func(err error) error {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand All @@ -352,7 +352,7 @@ func inBuilderContext(ctx context.Context, b solver.Builder, name, id string, f
Name: name,
}
return b.InContext(ctx, func(ctx context.Context, g session.Group) error {
pw, _, ctx := progress.FromContext(ctx, progress.WithMetadata("vertex", v.Digest))
pw, _, ctx := progress.NewFromContext(ctx, progress.WithMetadata("vertex", v.Digest))
notifyStarted(ctx, &v, false)
defer pw.Close()
err := f(ctx, g)
Expand All @@ -362,7 +362,7 @@ func inBuilderContext(ctx context.Context, b solver.Builder, name, id string, f
}

func notifyStarted(ctx context.Context, v *client.Vertex, cached bool) {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
defer pw.Close()
now := time.Now()
v.Started = &now
Expand All @@ -372,7 +372,7 @@ func notifyStarted(ctx context.Context, v *client.Vertex, cached bool) {
}

func notifyCompleted(ctx context.Context, v *client.Vertex, err error, cached bool) {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
defer pw.Close()
now := time.Now()
if v.Started == nil {
Expand Down
5 changes: 2 additions & 3 deletions source/containerimage/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,8 @@ func (p *puller) CacheKey(ctx context.Context, g session.Group, index int) (cach
}

if len(p.manifest.Descriptors) > 0 {
pw, _, _ := progress.FromContext(ctx)
progressController := &controller.Controller{
Writer: pw,
WriterFactory: progress.FromContext(ctx),
}
if p.vtx != nil {
progressController.Digest = p.vtx.Digest()
Expand Down Expand Up @@ -369,7 +368,7 @@ func cacheKeyFromConfig(dt []byte) digest.Digest {
}

func oneOffProgress(ctx context.Context, id string) func(err error) error {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand Down
2 changes: 1 addition & 1 deletion source/local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func (ls *localSourceHandler) snapshot(ctx context.Context, s session.Group, cal

func newProgressHandler(ctx context.Context, id string) func(int, bool) {
limiter := rate.NewLimiter(rate.Every(100*time.Millisecond), 1)
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand Down
4 changes: 2 additions & 2 deletions util/flightcontrol/flightcontrol.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) {
<-c.cleaned
return nil, errRetry
}
pw, ok, _ := progress.FromContext(ctx)
pw, ok, _ := progress.NewFromContext(ctx)
if ok {
c.progressState.add(pw)
}
Expand All @@ -149,7 +149,7 @@ func (c *call) wait(ctx context.Context) (v interface{}, err error) {
default:
}

pw, ok, ctx := progress.FromContext(ctx)
pw, ok, ctx := progress.NewFromContext(ctx)
if ok {
c.progressState.add(pw)
}
Expand Down
48 changes: 26 additions & 22 deletions util/progress/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,57 +13,61 @@ import (
type Controller struct {
count int64
started *time.Time
writer progress.Writer

Digest digest.Digest
Name string
Writer progress.Writer
Digest digest.Digest
Name string
WriterFactory progress.WriterFactory
}

var _ progress.Controller = &Controller{}

func (c *Controller) Start(ctx context.Context) (context.Context, func(error)) {
if c.Digest == "" {
return progress.WithProgress(ctx, c.Writer), func(error) {}
}

if atomic.AddInt64(&c.count, 1) == 1 {
if c.started == nil {
now := time.Now()
c.started = &now
c.writer, _, _ = c.WriterFactory(ctx)
}

if c.Digest != "" {
c.writer.Write(c.Digest.String(), client.Vertex{
Digest: c.Digest,
Name: c.Name,
Started: c.started,
})
}
c.Writer.Write(c.Digest.String(), client.Vertex{
Digest: c.Digest,
Name: c.Name,
Started: c.started,
})
}
return progress.WithProgress(ctx, c.Writer), func(err error) {
return progress.WithProgress(ctx, c.writer), func(err error) {
if atomic.AddInt64(&c.count, -1) == 0 {
now := time.Now()
var errString string
if err != nil {
errString = err.Error()
}
c.Writer.Write(c.Digest.String(), client.Vertex{
Digest: c.Digest,
Name: c.Name,
Started: c.started,
Completed: &now,
Error: errString,
})
if c.Digest != "" {
c.writer.Write(c.Digest.String(), client.Vertex{
Digest: c.Digest,
Name: c.Name,
Started: c.started,
Completed: &now,
Error: errString,
})
}
c.writer.Close()
}
}
}

func (c *Controller) Status(id string, action string) func() {
start := time.Now()
c.Writer.Write(id, progress.Status{
c.writer.Write(id, progress.Status{
Action: action,
Started: &start,
})
return func() {
complete := time.Now()
c.Writer.Write(id, progress.Status{
c.writer.Write(id, progress.Status{
Action: action,
Started: &start,
Completed: &complete,
Expand Down
4 changes: 2 additions & 2 deletions util/progress/logs/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func NewLogStreams(ctx context.Context, printOutput bool) (io.WriteCloser, io.Wr
}

func newStreamWriter(ctx context.Context, stream int, printOutput bool) io.WriteCloser {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
return &streamWriter{
pw: pw,
stream: stream,
Expand Down Expand Up @@ -132,7 +132,7 @@ func (sw *streamWriter) Close() error {

func LoggerFromContext(ctx context.Context) func([]byte) {
return func(dt []byte) {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
defer pw.Close()
pw.Write(identity.NewID(), client.VertexLog{
Stream: stderr,
Expand Down
2 changes: 1 addition & 1 deletion util/progress/multireader.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (mr *MultiReader) Reader(ctx context.Context) Reader {
defer mr.mu.Unlock()

pr, ctx, closeWriter := NewContext(ctx)
pw, _, ctx := FromContext(ctx)
pw, _, ctx := NewFromContext(ctx)

w := pw.(*progressWriter)
mr.writers[w] = closeWriter
Expand Down
41 changes: 28 additions & 13 deletions util/progress/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,37 @@ type contextKeyT string

var contextKey = contextKeyT("buildkit/util/progress")

// FromContext returns a progress writer from a context.
func FromContext(ctx context.Context, opts ...WriterOption) (Writer, bool, context.Context) {
// WriterFactory will generate a new progress Writer and return a new Context
// with the new Writer stored. It is the callers responsibility to Close the
// returned Writer to avoid resource leaks.
type WriterFactory func(ctx context.Context) (Writer, bool, context.Context)

// FromContext returns a WriterFactory to generate new progress writers based
// on a Writer previously stored in the Context.
func FromContext(ctx context.Context, opts ...WriterOption) WriterFactory {
v := ctx.Value(contextKey)
pw, ok := v.(*progressWriter)
if !ok {
if pw, ok := v.(*MultiWriter); ok {
return pw, true, ctx
return func(ctx context.Context) (Writer, bool, context.Context) {
pw, ok := v.(*progressWriter)
if !ok {
if pw, ok := v.(*MultiWriter); ok {
return pw, true, ctx
}
return &noOpWriter{}, false, ctx
}
return &noOpWriter{}, false, ctx
}
pw = newWriter(pw)
for _, o := range opts {
o(pw)
pw = newWriter(pw)
for _, o := range opts {
o(pw)
}
ctx = context.WithValue(ctx, contextKey, pw)
return pw, true, ctx
}
ctx = context.WithValue(ctx, contextKey, pw)
return pw, true, ctx
}

// NewFromContext creates a new Writer based on a Writer previously stored
// in the Context and returns a new Context with the new Writer stored. It is
// the callers responsibility to Close the returned Writer to avoid resource leaks.
func NewFromContext(ctx context.Context, opts ...WriterOption) (Writer, bool, context.Context) {
return FromContext(ctx, opts...)(ctx)
}

type WriterOption func(Writer)
Expand Down
6 changes: 3 additions & 3 deletions util/progress/progress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestProgress(t *testing.T) {
return saveProgress(ctx, pr, &trace)
})

pw, _, ctx := FromContext(ctx, WithMetadata("tag", "foo"))
pw, _, ctx := NewFromContext(ctx, WithMetadata("tag", "foo"))
s, err = calc(ctx, 5, "calc")
pw.Close()
assert.NoError(t, err)
Expand Down Expand Up @@ -66,7 +66,7 @@ func TestProgressNested(t *testing.T) {
}

func calc(ctx context.Context, total int, name string) (int, error) {
pw, _, ctx := FromContext(ctx)
pw, _, ctx := NewFromContext(ctx)
defer pw.Close()

sum := 0
Expand All @@ -91,7 +91,7 @@ func calc(ctx context.Context, total int, name string) (int, error) {
func reduceCalc(ctx context.Context, total int) (int, error) {
eg, ctx := errgroup.WithContext(ctx)

pw, _, ctx := FromContext(ctx)
pw, _, ctx := NewFromContext(ctx)
defer pw.Close()

pw.Write("reduce", Status{Action: "starting"})
Expand Down
2 changes: 1 addition & 1 deletion util/pull/pullprogress/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func trackProgress(ctx context.Context, desc ocispec.Descriptor, manager PullMan
ticker.Stop()
}()

pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
defer pw.Close()

ingestRef := remotes.MakeRefKey(ctx, desc)
Expand Down
2 changes: 1 addition & 1 deletion util/push/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func annotateDistributionSourceHandler(manager content.Manager, annotations map[
}

func oneOffProgress(ctx context.Context, id string) func(err error) error {
pw, _, _ := progress.FromContext(ctx)
pw, _, _ := progress.NewFromContext(ctx)
now := time.Now()
st := progress.Status{
Started: &now,
Expand Down
5 changes: 3 additions & 2 deletions worker/base/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,11 @@ func (w *Worker) Exporter(name string, sm *session.Manager) (exporter.Exporter,
}

func (w *Worker) FromRemote(ctx context.Context, remote *solver.Remote) (ref cache.ImmutableRef, err error) {
pw, _, _ := progress.FromContext(ctx)
descHandler := &cache.DescHandler{
Provider: func(session.Group) content.Provider { return remote.Provider },
Progress: &controller.Controller{Writer: pw},
Progress: &controller.Controller{
WriterFactory: progress.FromContext(ctx),
},
}
descHandlers := cache.DescHandlers(make(map[digest.Digest]*cache.DescHandler))
for _, desc := range remote.Descriptors {
Expand Down

0 comments on commit f4fcba5

Please sign in to comment.