Skip to content

Commit

Permalink
use a single generic as suggested by @Wondertan
Browse files Browse the repository at this point in the history
  • Loading branch information
Stebalien committed Aug 3, 2024
1 parent 410665a commit d55f695
Show file tree
Hide file tree
Showing 10 changed files with 234 additions and 212 deletions.
19 changes: 10 additions & 9 deletions cbor_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var _ = xerrors.Errorf

var lengthBufNode = []byte{130}

func (t *Node[V, T]) MarshalCBOR(w io.Writer) error {
func (t *Node[T]) MarshalCBOR(w io.Writer) error {
if t == nil {
_, err := w.Write(cbg.CborNull)
return err
Expand Down Expand Up @@ -61,8 +61,8 @@ func (t *Node[V, T]) MarshalCBOR(w io.Writer) error {
return nil
}

func (t *Node[V, T]) UnmarshalCBOR(r io.Reader) (err error) {
*t = Node[V, T]{}
func (t *Node[T]) UnmarshalCBOR(r io.Reader) (err error) {
*t = Node[T]{}

cr := cbg.NewCborReader(r)

Expand Down Expand Up @@ -124,12 +124,12 @@ func (t *Node[V, T]) UnmarshalCBOR(r io.Reader) (err error) {
}

if extra > 0 {
t.Pointers = make([]*Pointer[V, T], extra)
t.Pointers = make([]*Pointer[T], extra)
}

for i := 0; i < int(extra); i++ {

var v Pointer[V, T]
var v Pointer[T]
if err := v.UnmarshalCBOR(cr); err != nil {
return err
}
Expand All @@ -142,7 +142,7 @@ func (t *Node[V, T]) UnmarshalCBOR(r io.Reader) (err error) {

var lengthBufKV = []byte{130}

func (t *KV[V, T]) MarshalCBOR(w io.Writer) error {
func (t *KV[T]) MarshalCBOR(w io.Writer) error {
if t == nil {
_, err := w.Write(cbg.CborNull)
return err
Expand Down Expand Up @@ -174,8 +174,8 @@ func (t *KV[V, T]) MarshalCBOR(w io.Writer) error {
return nil
}

func (t *KV[V, T]) UnmarshalCBOR(r io.Reader) (err error) {
*t = KV[V, T]{}
func (t *KV[T]) UnmarshalCBOR(r io.Reader) (err error) {
*t = KV[T]{}

cr := cbg.NewCborReader(r)

Expand Down Expand Up @@ -223,7 +223,8 @@ func (t *KV[V, T]) UnmarshalCBOR(r io.Reader) (err error) {

{

var value T = new(V)
var value T
value = value.New()
if err := value.UnmarshalCBOR(cr); err != nil {
return xerrors.Errorf("failed to read field: %w", err)
}
Expand Down
56 changes: 28 additions & 28 deletions diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,30 @@ const (

// Change represents a change to a DAG and contains a reference to the old and
// new CIDs.
type Change[V any] struct {
type Change[T any] struct {
Type ChangeType
Key string
Before *V
After *V
Before T
After T
}

func (ch Change[V]) String() string {
func (ch Change[T]) String() string {
b, _ := json.Marshal(ch)
return string(b)
}

// Diff returns a set of changes that transform node 'prev' into node 'cur'. opts are applied to both prev and cur.
func Diff[V any, T HamtValue[V]](ctx context.Context, prevBs, curBs cbor.IpldStore, prev, cur cid.Cid, opts ...Option) ([]*Change[V], error) {
func Diff[T HamtValue[T]](ctx context.Context, prevBs, curBs cbor.IpldStore, prev, cur cid.Cid, opts ...Option) ([]*Change[T], error) {
if prev.Equals(cur) {
return nil, nil
}

prevHamt, err := LoadNode[V, T](ctx, prevBs, prev, opts...)
prevHamt, err := LoadNode[T](ctx, prevBs, prev, opts...)
if err != nil {
return nil, err
}

curHamt, err := LoadNode[V, T](ctx, curBs, cur, opts...)
curHamt, err := LoadNode[T](ctx, curBs, cur, opts...)
if err != nil {
return nil, err
}
Expand All @@ -55,7 +55,7 @@ func Diff[V any, T HamtValue[V]](ctx context.Context, prevBs, curBs cbor.IpldSto
return diffNode(ctx, prevHamt, curHamt, 0)
}

func diffNode[V any, T HamtValue[V]](ctx context.Context, pre, cur *Node[V, T], depth int) ([]*Change[V], error) {
func diffNode[T HamtValue[T]](ctx context.Context, pre, cur *Node[T], depth int) ([]*Change[T], error) {
// which Bitfield contains the most bits. We will start a loop from this index, calling Bitfield.Bit(idx)
// on an out of range index will return zero.
bp := cur.Bitfield.BitLen()
Expand All @@ -64,7 +64,7 @@ func diffNode[V any, T HamtValue[V]](ctx context.Context, pre, cur *Node[V, T],
}

// the changes between cur and prev
var changes []*Change[V]
var changes []*Change[T]

// loop over each bit in the bitfields
for idx := bp; idx >= 0; idx-- {
Expand Down Expand Up @@ -136,11 +136,11 @@ func diffNode[V any, T HamtValue[V]](ctx context.Context, pre, cur *Node[V, T],
changes = append(changes, rm...)
} else {
for _, p := range pointer.KVs {
changes = append(changes, &Change[V]{
changes = append(changes, &Change[T]{
Type: Remove,
Key: string(p.Key),
Before: p.Value,
After: nil,
After: zero[T](),
})
}
}
Expand All @@ -160,10 +160,10 @@ func diffNode[V any, T HamtValue[V]](ctx context.Context, pre, cur *Node[V, T],
changes = append(changes, add...)
} else {
for _, p := range pointer.KVs {
changes = append(changes, &Change[V]{
changes = append(changes, &Change[T]{
Type: Add,
Key: string(p.Key),
Before: nil,
Before: zero[T](),
After: p.Value,
})
}
Expand All @@ -174,10 +174,10 @@ func diffNode[V any, T HamtValue[V]](ctx context.Context, pre, cur *Node[V, T],
return changes, nil
}

func diffKVs[V any, T HamtValue[V]](pre, cur []*KV[V, T], idx int) []*Change[V] {
func diffKVs[T HamtValue[T]](pre, cur []*KV[T], idx int) []*Change[T] {
preMap := make(map[string]T, len(pre))
curMap := make(map[string]T, len(cur))
var changes []*Change[V]
var changes []*Change[T]

for _, kv := range pre {
preMap[string(kv.Key)] = kv.Value
Expand All @@ -188,27 +188,27 @@ func diffKVs[V any, T HamtValue[V]](pre, cur []*KV[V, T], idx int) []*Change[V]
// find removed keys: keys in pre and not in cur
for key, value := range preMap {
if _, ok := curMap[key]; !ok {
changes = append(changes, &Change[V]{
changes = append(changes, &Change[T]{
Type: Remove,
Key: key,
Before: value,
After: nil,
After: zero[T](),
})
}
}
// find added keys: keys in cur and not in pre
// find modified values: keys in cur and pre with different values
for key, curVal := range curMap {
if preVal, ok := preMap[key]; !ok {
changes = append(changes, &Change[V]{
changes = append(changes, &Change[T]{
Type: Add,
Key: key,
Before: nil,
Before: zero[T](),
After: curVal,
})
} else {
if !preVal.Equal(curVal) {
changes = append(changes, &Change[V]{
changes = append(changes, &Change[T]{
Type: Modify,
Key: key,
Before: preVal,
Expand All @@ -220,13 +220,13 @@ func diffKVs[V any, T HamtValue[V]](pre, cur []*KV[V, T], idx int) []*Change[V]
return changes
}

func addAll[V any, T HamtValue[V]](ctx context.Context, node *Node[V, T], idx int) ([]*Change[V], error) {
var changes []*Change[V]
func addAll[T HamtValue[T]](ctx context.Context, node *Node[T], idx int) ([]*Change[T], error) {
var changes []*Change[T]
if err := node.ForEach(ctx, func(k string, val T) error {
changes = append(changes, &Change[V]{
changes = append(changes, &Change[T]{
Type: Add,
Key: k,
Before: nil,
Before: zero[T](),
After: val,
})

Expand All @@ -237,14 +237,14 @@ func addAll[V any, T HamtValue[V]](ctx context.Context, node *Node[V, T], idx in
return changes, nil
}

func removeAll[V any, T HamtValue[V]](ctx context.Context, node *Node[V, T], idx int) ([]*Change[V], error) {
var changes []*Change[V]
func removeAll[T HamtValue[T]](ctx context.Context, node *Node[T], idx int) ([]*Change[T], error) {
var changes []*Change[T]
if err := node.ForEach(ctx, func(k string, val T) error {
changes = append(changes, &Change[V]{
changes = append(changes, &Change[T]{
Type: Remove,
Key: k,
Before: val,
After: nil,
After: zero[T](),
})

return nil
Expand Down
Loading

0 comments on commit d55f695

Please sign in to comment.