diff --git a/merkledag.go b/merkledag.go index c035dd4..568d445 100644 --- a/merkledag.go +++ b/merkledag.go @@ -162,7 +162,7 @@ func FetchGraph(ctx context.Context, root cid.Cid, serv ipld.DAGService) error { } // FetchGraphWithDepthLimit fetches all nodes that are children to the given -// node down to the given depth. maxDetph=0 means "only fetch root", +// node down to the given depth. maxDepth=0 means "only fetch root", // maxDepth=1 means "fetch root and its direct children" and so on... // maxDepth=-1 means unlimited. func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, serv ipld.DAGService) error { @@ -195,9 +195,10 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s return false } + // If we have a ProgressTracker, we wrap the visit function to handle it v, _ := ctx.Value(progressContextKey).(*ProgressTracker) if v == nil { - return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visit) + return WalkDepth(ctx, GetLinksDirect(ng), root, visit, Concurrent(), WithRoot()) } visitProgress := func(c cid.Cid, depth int) bool { @@ -207,7 +208,7 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s } return false } - return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visitProgress) + return WalkDepth(ctx, GetLinksDirect(ng), root, visitProgress, Concurrent(), WithRoot()) } // GetMany gets many nodes from the DAG at once. 
@@ -281,30 +282,143 @@ func GetLinksWithDAG(ng ipld.NodeGetter) GetLinks { } } +// defaultConcurrentFetch is the default maximum number of concurrent fetches +// that 'fetchNodes' will start at a time +const defaultConcurrentFetch = 32 + +// walkOptions represents the parameters of a graph walking algorithm +type walkOptions struct { + WithRoot bool + Concurrency int + ErrorHandler func(c cid.Cid, err error) error +} + +// WalkOption is a setter for walkOptions +type WalkOption func(*walkOptions) + +func (wo *walkOptions) addHandler(handler func(c cid.Cid, err error) error) { + if prev := wo.ErrorHandler; prev != nil { // capture the current handler; reading wo.ErrorHandler inside the closure would call the closure itself + wo.ErrorHandler = func(c cid.Cid, err error) error { + return handler(c, prev(c, err)) + } + } else { + wo.ErrorHandler = handler + } +} + +// WithRoot is a WalkOption indicating that the root node should be visited +func WithRoot() WalkOption { + return func(walkOptions *walkOptions) { + walkOptions.WithRoot = true + } +} + +// Concurrent is a WalkOption indicating that node fetching should be done in +// parallel, with the default concurrency factor. +// NOTE: When using that option, the walk order is *not* guaranteed. +// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function. +func Concurrent() WalkOption { + return func(walkOptions *walkOptions) { + walkOptions.Concurrency = defaultConcurrentFetch + } +} + +// Concurrency is a WalkOption indicating that node fetching should be done in +// parallel, with a specific concurrency factor. +// NOTE: When using that option, the walk order is *not* guaranteed. +// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function. +func Concurrency(worker int) WalkOption { + return func(walkOptions *walkOptions) { + walkOptions.Concurrency = worker + } +} + +// IgnoreErrors is a WalkOption indicating that the walk should attempt to +// continue even when an error occurs. 
+func IgnoreErrors() WalkOption { + return func(walkOptions *walkOptions) { + walkOptions.addHandler(func(c cid.Cid, err error) error { + return nil // every error is discarded so the walk keeps going + }) + } +} + +// IgnoreMissing is a WalkOption indicating that the walk should continue when +// a node is missing. +func IgnoreMissing() WalkOption { + return func(walkOptions *walkOptions) { + walkOptions.addHandler(func(c cid.Cid, err error) error { + if err == ipld.ErrNotFound { + return nil + } + return err + }) + } +} + +// OnMissing is a WalkOption adding a callback that will be triggered on a missing +// node. +func OnMissing(callback func(c cid.Cid)) WalkOption { + return func(walkOptions *walkOptions) { + walkOptions.addHandler(func(c cid.Cid, err error) error { + if err == ipld.ErrNotFound { + callback(c) + } + return err + }) + } +} + +// OnError is a WalkOption adding a custom error handler. +// If this handler returns a nil error, the walk will continue. +func OnError(handler func(c cid.Cid, err error) error) WalkOption { + return func(walkOptions *walkOptions) { + walkOptions.addHandler(handler) + } +} + // WalkGraph will walk the dag in order (depth first) starting at the given root. -func Walk(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid) bool) error { +func Walk(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool, options ...WalkOption) error { visitDepth := func(c cid.Cid, depth int) bool { return visit(c) } - return WalkDepth(ctx, getLinks, root, 0, visitDepth) + return WalkDepth(ctx, getLinks, c, visitDepth, options...) } // WalkDepth walks the dag starting at the given root and passes the current // depth to a given visit function. The visit function can be used to limit DAG // exploration. 
-func WalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool) error { - if !visit(root, depth) { - return nil +func WalkDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid, int) bool, options ...WalkOption) error { + opts := &walkOptions{} + for _, opt := range options { + opt(opts) + } + + if opts.Concurrency > 1 { // an unset (0) or single-worker concurrency uses the sequential walk + return parallelWalkDepth(ctx, getLinks, c, visit, opts) + } + + return sequentialWalkDepth(ctx, getLinks, c, 0, visit, opts) +} + +func sequentialWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool, options *walkOptions) error { + if depth != 0 || options.WithRoot { // the node at depth 0 is only passed to visit when WithRoot is set + if !visit(root, depth) { + return nil + } } links, err := getLinks(ctx, root) + if err != nil && options.ErrorHandler != nil { + err = options.ErrorHandler(root, err) // a handler returning nil lets the walk continue + } if err != nil { return err } for _, lnk := range links { - if err := WalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit); err != nil { + if err := sequentialWalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit, options); err != nil { return err } } @@ -337,27 +451,7 @@ func (p *ProgressTracker) Value() int { return p.Total } -// FetchGraphConcurrency is total number of concurrent fetches -// that 'fetchNodes' will start at a time -var FetchGraphConcurrency = 32 - -// WalkParallel is equivalent to Walk *except* that it explores multiple paths -// in parallel. -// -// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function. -func WalkParallel(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool) error { - visitDepth := func(c cid.Cid, depth int) bool { - return visit(c) - } - - return WalkParallelDepth(ctx, getLinks, c, 0, visitDepth) -} - -// WalkParallelDepth is equivalent to WalkDepth *except* that it fetches -// children in parallel. -// -// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function. 
-func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startDepth int, visit func(cid.Cid, int) bool) error { +func parallelWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid, int) bool, options *walkOptions) error { type cidDepth struct { cid cid.Cid depth int @@ -372,14 +466,14 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD out := make(chan *linksDepth) done := make(chan struct{}) - var setlk sync.Mutex + var visitlk sync.Mutex var wg sync.WaitGroup errChan := make(chan error) fetchersCtx, cancel := context.WithCancel(ctx) defer wg.Wait() defer cancel() - for i := 0; i < FetchGraphConcurrency; i++ { + for i := 0; i < options.Concurrency; i++ { wg.Add(1) go func() { defer wg.Done() @@ -387,12 +481,22 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD ci := cdepth.cid depth := cdepth.depth - setlk.Lock() - shouldVisit := visit(ci, depth) - setlk.Unlock() + var shouldVisit bool + + // bypass the root if needed + if depth != 0 || options.WithRoot { + visitlk.Lock() + shouldVisit = visit(ci, depth) + visitlk.Unlock() + } else { + shouldVisit = true + } if shouldVisit { links, err := getLinks(ctx, ci) + if err != nil && options.ErrorHandler != nil { + err = options.ErrorHandler(ci, err) // report the CID whose links failed to load, not the walk root + } if err != nil { select { case errChan <- err: @@ -422,20 +526,21 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD defer close(feed) send := feed - var todobuffer []*cidDepth + var todoQueue []*cidDepth var inProgress int next := &cidDepth{ - cid: c, - depth: startDepth, + cid: root, + depth: 0, } + for { select { case send <- next: inProgress++ - if len(todobuffer) > 0 { - next = todobuffer[0] - todobuffer = todobuffer[1:] + if len(todoQueue) > 0 { + next = todoQueue[0] + todoQueue = todoQueue[1:] } else { next = nil send = nil @@ -456,7 +561,7 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD 
next = cd send = feed } else { - todobuffer = append(todobuffer, cd) + todoQueue = append(todoQueue, cd) } } case err := <-errChan: @@ -466,7 +571,6 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD return ctx.Err() } } - } var _ ipld.LinkGetter = &dagService{} diff --git a/merkledag_test.go b/merkledag_test.go index e56bb52..045a880 100644 --- a/merkledag_test.go +++ b/merkledag_test.go @@ -203,8 +203,11 @@ func makeTestDAG(t *testing.T, read io.Reader, ds ipld.DAGService) ipld.Node { // Add a root referencing all created nodes root := NodeWithData(nil) for _, n := range nodes { - root.AddNodeLink(n.Cid().String(), n) - err := ds.Add(ctx, n) + err := root.AddNodeLink(n.Cid().String(), n) + if err != nil { + t.Fatal(err) + } + err = ds.Add(ctx, n) if err != nil { t.Fatal(err) } @@ -383,7 +386,7 @@ func TestFetchGraphWithDepthLimit(t *testing.T) { } - err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), 0, visitF) + err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), visitF, WithRoot()) if err != nil { t.Fatal(err) } @@ -736,7 +739,7 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) { } cset := cid.NewSet() - err = WalkParallel(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit) + err = Walk(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit) if err == nil { t.Fatal("this should have failed") }