Skip to content

Commit

Permalink
feat: remove block decoding global registry
Browse files Browse the repository at this point in the history
  • Loading branch information
aschmahmann committed Jun 6, 2023
1 parent e379dec commit 7719ef4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 32 deletions.
62 changes: 35 additions & 27 deletions coding.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,53 @@ package format

import (
"fmt"
"sync"

blocks "github.com/ipfs/go-block-format"
)

// DecodeBlockFunc functions decode blocks into nodes.
type DecodeBlockFunc func(block blocks.Block) (Node, error)

type BlockDecoder interface {
Register(codec uint64, decoder DecodeBlockFunc)
Decode(blocks.Block) (Node, error)
}
type safeBlockDecoder struct {
// Can be replaced with an RCU if necessary.
lock sync.RWMutex
// Registry is a structure for storing mappings of multicodec IPLD codec numbers to DecodeBlockFunc functions.
//
// Registry includes no mutexing. If using Registry in a concurrent context, you must handle synchronization yourself.
// (Typically, it is recommended to do initialization earlier in a program, before fanning out goroutines;
// this avoids the need for mutexing overhead.)
//
// Multicodec indicator numbers are specified in
// https://github.com/multiformats/multicodec/blob/master/table.csv .
// You should not use indicator numbers which are not specified in that table
// (however, there is nothing in this implementation that will attempt to stop you, either).
type Registry struct {
decoders map[uint64]DecodeBlockFunc
}

func (r *Registry) ensureInit() {
if r.decoders != nil {
return
}
r.decoders = make(map[uint64]DecodeBlockFunc)
}

// Register registers decoder for all blocks with the passed codec.
//
// This will silently replace any existing registered block decoders.
func (d *safeBlockDecoder) Register(codec uint64, decoder DecodeBlockFunc) {
d.lock.Lock()
defer d.lock.Unlock()
d.decoders[codec] = decoder
func (r *Registry) Register(codec uint64, decoder DecodeBlockFunc) {
r.ensureInit()
if decoder == nil {
panic("not sensible to attempt to register a nil function")
}
r.decoders[codec] = decoder
}

func (d *safeBlockDecoder) Decode(block blocks.Block) (Node, error) {
func (r *Registry) Decode(block blocks.Block) (Node, error) {
// Short-circuit by cast if we already have a Node.
if node, ok := block.(Node); ok {
return node, nil
}

ty := block.Cid().Type()

d.lock.RLock()
decoder, ok := d.decoders[ty]
d.lock.RUnlock()
r.ensureInit()
decoder, ok := r.decoders[ty]

if ok {
return decoder(block)
Expand All @@ -49,14 +58,13 @@ func (d *safeBlockDecoder) Decode(block blocks.Block) (Node, error) {
}
}

var DefaultBlockDecoder BlockDecoder = &safeBlockDecoder{decoders: make(map[uint64]DecodeBlockFunc)}

// Decode decodes the given block using the default BlockDecoder.
func Decode(block blocks.Block) (Node, error) {
return DefaultBlockDecoder.Decode(block)
}
// Decode decodes the given block using passed DecodeBlockFunc.
// Note: this is just a helper function, consider using the DecodeBlockFunc itself rather than this helper
func Decode(block blocks.Block, decoder DecodeBlockFunc) (Node, error) {
// Short-circuit by cast if we already have a Node.
if node, ok := block.(Node); ok {
return node, nil
}

// Register registers block decoders with the default BlockDecoder.
func Register(codec uint64, decoder DecodeBlockFunc) {
DefaultBlockDecoder.Register(codec, decoder)
return decoder(block)
}
54 changes: 49 additions & 5 deletions coding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,53 @@ import (
mh "github.com/multiformats/go-multihash"
)

func init() {
Register(cid.Raw, func(b blocks.Block) (Node, error) {
func TestDecode(t *testing.T) {
decoder := func(b blocks.Block) (Node, error) {
node := &EmptyNode{}
if b.RawData() != nil || !b.Cid().Equals(node.Cid()) {
return nil, errors.New("can only decode empty blocks")
}
return node, nil
})
}

id, err := cid.Prefix{
Version: 1,
Codec: cid.Raw,
MhType: mh.ID,
MhLength: 0,
}.Sum(nil)

if err != nil {
t.Fatalf("failed to create cid: %s", err)
}

block, err := blocks.NewBlockWithCid(nil, id)
if err != nil {
t.Fatalf("failed to create empty block: %s", err)
}
node, err := Decode(block, decoder)
if err != nil {
t.Fatalf("failed to decode empty node: %s", err)
}
if !node.Cid().Equals(id) {
t.Fatalf("empty node doesn't have the right cid")
}

if _, ok := node.(*EmptyNode); !ok {
t.Fatalf("empty node doesn't have the right type")
}

}

func TestDecode(t *testing.T) {
func TestRegistryDecode(t *testing.T) {
decoder := func(b blocks.Block) (Node, error) {
node := &EmptyNode{}
if b.RawData() != nil || !b.Cid().Equals(node.Cid()) {
return nil, errors.New("can only decode empty blocks")
}
return node, nil
}

id, err := cid.Prefix{
Version: 1,
Codec: cid.Raw,
Expand All @@ -35,10 +71,18 @@ func TestDecode(t *testing.T) {
if err != nil {
t.Fatalf("failed to create empty block: %s", err)
}
node, err := Decode(block)

reg := Registry{}
_, err = reg.Decode(block)
if err == nil || err.Error() != "unrecognized object type: 85" {
t.Fatalf("expected error, got %v", err)
}
reg.Register(cid.Raw, decoder)
node, err := reg.Decode(block)
if err != nil {
t.Fatalf("failed to decode empty node: %s", err)
}

if !node.Cid().Equals(id) {
t.Fatalf("empty node doesn't have the right cid")
}
Expand Down

0 comments on commit 7719ef4

Please sign in to comment.