diff --git a/collections/indexes/doc.go b/collections/indexes/doc.go new file mode 100644 index 000000000000..aa4b584cc7b8 --- /dev/null +++ b/collections/indexes/doc.go @@ -0,0 +1,3 @@ +// Package indexes contains the most common indexes types to be used with a collections.IndexedMap. +// It also contains specialised helper functions to collect and query efficiently an index. +package indexes diff --git a/collections/indexes/helpers.go b/collections/indexes/helpers.go new file mode 100644 index 000000000000..666f6cc6c9f7 --- /dev/null +++ b/collections/indexes/helpers.go @@ -0,0 +1,110 @@ +package indexes + +import ( + "context" + "cosmossdk.io/collections" +) + +// Iterator defines the minimum set of methods of an index iterator +// required to work with the helpers. +type Iterator[K any] interface { + // PrimaryKey returns the iterator current primary key. + PrimaryKey() (K, error) + // Next advances the iterator by one element. + Next() + // Valid asserts if the Iterator is valid. + Valid() bool + // Close closes the iterator. + Close() error +} + +// CollectKeyValues collects all the keys and the values of an indexed map index iterator. +// The Iterator is fully consumed and closed. +func CollectKeyValues[K, V any, I Iterator[K], Idx collections.Indexes[K, V]]( + ctx context.Context, + indexedMap *collections.IndexedMap[K, V, Idx], + iter I) (kvs []collections.KeyValue[K, V], err error) { + err = ScanKeyValues(ctx, indexedMap, iter, func(kv collections.KeyValue[K, V]) bool { + kvs = append(kvs, kv) + return false + }) + return +} + +// ScanKeyValues calls the do function on every record found, in the indexed map +// from the index iterator. Returning false stops the iteration. +// The Iterator is closed when this function exits. +func ScanKeyValues[K, V any, I Iterator[K], Idx collections.Indexes[K, V]]( + ctx context.Context, + indexedMap *collections.IndexedMap[K, V, Idx], + iter I, + do func(kv collections.KeyValue[K, V]) (stop bool)) (err error) { + + defer iter.Close() + + for ; iter.Valid(); iter.Next() { + pk, err := iter.PrimaryKey() + if err != nil { + return err + } + + value, err := indexedMap.Get(ctx, pk) + if err != nil { + return err + } + + kv := collections.KeyValue[K, V]{ + Key: pk, + Value: value, + } + + if do(kv) { + break + } + } + + return nil +} + +// CollectValues collects all the values from an Index iterator and the IndexedMap. +// Closes the Iterator. +func CollectValues[K, V any, I Iterator[K], Idx collections.Indexes[K, V]]( + ctx context.Context, + indexedMap *collections.IndexedMap[K, V, Idx], + iter I) (values []V, err error) { + err = ScanValues(ctx, indexedMap, iter, func(value V) (stop bool) { + values = append(values, value) + return false + }) + return +} + +// ScanValues collects all the values from an Index iterator and the IndexedMap in a lazy way. +// The iterator is closed when this function exits. +func ScanValues[K, V any, I Iterator[K], Idx collections.Indexes[K, V]]( + ctx context.Context, + indexedMap *collections.IndexedMap[K, V, Idx], + iter I, + f func(value V) (stop bool), +) error { + defer iter.Close() + + for ; iter.Valid(); iter.Next() { + key, err := iter.PrimaryKey() + if err != nil { + return err + } + + value, err := indexedMap.Get(ctx, key) + if err != nil { + return err + } + + stop := f(value) + if stop { + return nil + } + } + + return nil +} diff --git a/collections/indexes/helpers_test.go b/collections/indexes/helpers_test.go new file mode 100644 index 000000000000..691bec55ad2f --- /dev/null +++ b/collections/indexes/helpers_test.go @@ -0,0 +1,87 @@ +package indexes + +import ( + "cosmossdk.io/collections" + "github.com/stretchr/testify/require" + "testing" +) + +func TestHelpers(t *testing.T) { + // uses MultiPair scenario. + // We store balances as: + // Key: Pair[Address=string, Denom=string] => Value: Amount=uint64 + + sk, ctx := deps() + sb := collections.NewSchemaBuilder(sk) + + keyCodec := collections.PairKeyCodec(collections.StringKey, collections.StringKey) + indexedMap := collections.NewIndexedMap( + sb, + collections.NewPrefix("balances"), "balances", + keyCodec, + collections.Uint64Value, + balanceIndex{ + Denom: NewMultiPair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec), + }, + ) + + err := indexedMap.Set(ctx, collections.Join("address1", "atom"), 100) + require.NoError(t, err) + + err = indexedMap.Set(ctx, collections.Join("address1", "osmo"), 200) + require.NoError(t, err) + + err = indexedMap.Set(ctx, collections.Join("address2", "osmo"), 300) + require.NoError(t, err) + + // test collect values + iter, err := indexedMap.Indexes.Denom.MatchExact(ctx, "osmo") + require.NoError(t, err) + + values, err := CollectValues(ctx, indexedMap, iter) + require.NoError(t, err) + require.Equal(t, []Amount{200, 300}, values) + + // test collect key values + + iter, err = indexedMap.Indexes.Denom.MatchExact(ctx, "osmo") + require.NoError(t, err) + kvs, err := CollectKeyValues(ctx, indexedMap, iter) + require.NoError(t, err) + + require.Equal(t, []collections.KeyValue[collections.Pair[Address, Denom], Amount]{ + { + Key: collections.Join("address1", "osmo"), + Value: 200, + }, + { + Key: collections.Join("address2", "osmo"), + Value: 300, + }, + }, kvs) + + // test scan values with early termination + iter, err = indexedMap.Indexes.Denom.MatchExact(ctx, "osmo") + require.NoError(t, err) + numCalled := 0 + err = ScanValues(ctx, indexedMap, iter, func(v Amount) bool { + require.Equal(t, Amount(200), v) + numCalled++ + require.Equal(t, numCalled, 1) + return true // says to stop + }) + require.NoError(t, err) + + // test scan kv with early termination + iter, err = indexedMap.Indexes.Denom.MatchExact(ctx, "osmo") + require.NoError(t, err) + numCalled = 0 + err = ScanKeyValues(ctx, indexedMap, iter, func(kv collections.KeyValue[collections.Pair[Address, Denom], Amount]) bool { + require.Equal(t, Amount(200), kv.Value) + require.Equal(t, collections.Join("address1", "osmo"), kv.Key) + numCalled++ + require.Equal(t, numCalled, 1) + return true // says to stop + }) + require.NoError(t, err) +} diff --git a/collections/indexes/indexes_test.go b/collections/indexes/indexes_test.go new file mode 100644 index 000000000000..be65add9adf5 --- /dev/null +++ b/collections/indexes/indexes_test.go @@ -0,0 +1,53 @@ +package indexes + +import ( + "context" + "cosmossdk.io/core/store" + db "github.com/cosmos/cosmos-db" +) + +// TODO remove this when we add testStore to core/store. + +type testStore struct { + db db.DB +} + +func (t testStore) OpenKVStore(ctx context.Context) store.KVStore { + return t +} + +func (t testStore) Get(key []byte) ([]byte, error) { + return t.db.Get(key) +} + +func (t testStore) Has(key []byte) (bool, error) { + return t.db.Has(key) +} + +func (t testStore) Set(key, value []byte) error { + return t.db.Set(key, value) +} + +func (t testStore) Delete(key []byte) error { + return t.db.Delete(key) +} + +func (t testStore) Iterator(start, end []byte) (store.Iterator, error) { + return t.db.Iterator(start, end) +} + +func (t testStore) ReverseIterator(start, end []byte) (store.Iterator, error) { + return t.db.ReverseIterator(start, end) +} + +var _ store.KVStore = testStore{} + +func deps() (store.KVStoreService, context.Context) { + kv := db.NewMemDB() + return &testStore{kv}, context.Background() +} + +type company struct { + City string + Vat uint64 +} diff --git a/collections/indexes/multi.go b/collections/indexes/multi.go new file mode 100644 index 000000000000..24cb07534ee8 --- /dev/null +++ b/collections/indexes/multi.go @@ -0,0 +1,104 @@ +package indexes + +import ( + "context" + "cosmossdk.io/collections" +) + +// Multi defines the most common index. It can be used to create a reference between +// a field of value and its primary key. Multiple primary keys can be mapped to the same +// reference key as the index does not enforce uniqueness constraints. +type Multi[ReferenceKey, PrimaryKey, Value any] collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value] + +// NewMulti instantiates a new Multi instance given a schema, +// a Prefix, the humanized name for the index, the reference key key codec +// and the primary key key codec. The getRefKeyFunc is a function that +// given the primary key and value returns the referencing key. +func NewMulti[ReferenceKey, PrimaryKey, Value any]( + schema *collections.SchemaBuilder, + prefix collections.Prefix, + name string, + refCodec collections.KeyCodec[ReferenceKey], + pkCodec collections.KeyCodec[PrimaryKey], + getRefKeyFunc func(pk PrimaryKey, value Value) (ReferenceKey, error), +) *Multi[ReferenceKey, PrimaryKey, Value] { + i := collections.NewGenericMultiIndex( + schema, prefix, name, refCodec, pkCodec, + func(pk PrimaryKey, value Value) ([]collections.IndexReference[ReferenceKey, PrimaryKey], error) { + ref, err := getRefKeyFunc(pk, value) + if err != nil { + return nil, err + } + return []collections.IndexReference[ReferenceKey, PrimaryKey]{ + collections.NewIndexReference(ref, pk), + }, nil + }, + ) + + return (*Multi[ReferenceKey, PrimaryKey, Value])(i) +} + +func (m *Multi[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error { + return (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Reference(ctx, pk, newValue, oldValue) +} + +func (m *Multi[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, value Value) error { + return (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Unreference(ctx, pk, value) +} + +func (m *Multi[ReferenceKey, PrimaryKey, Value]) Iterate(ctx context.Context, ranger collections.Ranger[collections.Pair[ReferenceKey, PrimaryKey]]) (MultiIterator[ReferenceKey, PrimaryKey], error) { + iter, err := (*collections.GenericMultiIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(m).Iterate(ctx, ranger) + return (MultiIterator[ReferenceKey, PrimaryKey])(iter), err +} + +// MatchExact returns a MultiIterator containing all the primary keys referenced by the provided reference key. +func (m *Multi[ReferenceKey, PrimaryKey, Value]) MatchExact(ctx context.Context, refKey ReferenceKey) (MultiIterator[ReferenceKey, PrimaryKey], error) { + return m.Iterate(ctx, collections.NewPrefixedPairRange[ReferenceKey, PrimaryKey](refKey)) +} + +// MultiIterator is just a KeySetIterator with key as Pair[ReferenceKey, PrimaryKey]. +type MultiIterator[ReferenceKey, PrimaryKey any] collections.KeySetIterator[collections.Pair[ReferenceKey, PrimaryKey]] + +// PrimaryKey returns the iterator's current primary key. +func (i MultiIterator[ReferenceKey, PrimaryKey]) PrimaryKey() (PrimaryKey, error) { + fullKey, err := i.FullKey() + return fullKey.K2(), err +} + +// PrimaryKeys fully consumes the iterator and returns the list of primary keys. +func (i MultiIterator[ReferenceKey, PrimaryKey]) PrimaryKeys() ([]PrimaryKey, error) { + fullKeys, err := i.FullKeys() + if err != nil { + return nil, err + } + pks := make([]PrimaryKey, len(fullKeys)) + for i, fullKey := range fullKeys { + pks[i] = fullKey.K2() + } + return pks, nil +} + +// FullKey returns the current full reference key as Pair[ReferenceKey, PrimaryKey]. +func (i MultiIterator[ReferenceKey, PrimaryKey]) FullKey() (collections.Pair[ReferenceKey, PrimaryKey], error) { + return (collections.KeySetIterator[collections.Pair[ReferenceKey, PrimaryKey]])(i).Key() +} + +// FullKeys fully consumes the iterator and returns all the list of full reference keys. +func (i MultiIterator[ReferenceKey, PrimaryKey]) FullKeys() ([]collections.Pair[ReferenceKey, PrimaryKey], error) { + return (collections.KeySetIterator[collections.Pair[ReferenceKey, PrimaryKey]])(i).Keys() +} + +// Next advances the iterator. +func (i MultiIterator[ReferenceKey, PrimaryKey]) Next() { + (collections.KeySetIterator[collections.Pair[ReferenceKey, PrimaryKey]])(i).Next() +} + +// Valid asserts if the iterator is still valid or not. +func (i MultiIterator[ReferenceKey, PrimaryKey]) Valid() bool { + return (collections.KeySetIterator[collections.Pair[ReferenceKey, PrimaryKey]])(i).Valid() +} + +// Close closes the iterator. +func (i MultiIterator[ReferenceKey, PrimaryKey]) Close() error { + return (collections.KeySetIterator[collections.Pair[ReferenceKey, PrimaryKey]])(i).Close() +} diff --git a/collections/indexes/multi_pair.go b/collections/indexes/multi_pair.go new file mode 100644 index 000000000000..819730b3842a --- /dev/null +++ b/collections/indexes/multi_pair.go @@ -0,0 +1,110 @@ +package indexes + +import ( + "context" + "cosmossdk.io/collections" +) + +// MultiPair is an index that is used with collections.Pair keys. It indexes objects by their second part of the key. +// When the value is being indexed by collections.IndexedMap then MultiPair will create a relationship between +// the second part of the primary key and the first part. +type MultiPair[K1, K2, Value any] collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value] + +// TODO(tip): this is an interface to cast a collections.KeyCodec +// to a pair codec. currently we return it as a KeyCodec[Pair[K1, K2]] +// to improve dev experience with type inference, which means we cannot +// get the concrete implementation which exposes KeyCodec1 and KeyCodec2. +type pairKeyCodec[K1, K2 any] interface { + KeyCodec1() collections.KeyCodec[K1] + KeyCodec2() collections.KeyCodec[K2] +} + +// NewMultiPair instantiates a new MultiPair index. +// NOTE: when using this function you will need to type hint: doing NewMultiPair[Value]() +// Example: if the value of the indexed map is string, you need to do NewMultiPair[string](...) +func NewMultiPair[Value any, K1, K2 any]( + sb *collections.SchemaBuilder, + prefix collections.Prefix, + name string, + pairCodec collections.KeyCodec[collections.Pair[K1, K2]], +) *MultiPair[K1, K2, Value] { + pkc := pairCodec.(pairKeyCodec[K1, K2]) + mi := collections.NewGenericMultiIndex( + sb, + prefix, + name, + pkc.KeyCodec2(), + pkc.KeyCodec1(), + func(pk collections.Pair[K1, K2], _ Value) ([]collections.IndexReference[K2, K1], error) { + return []collections.IndexReference[K2, K1]{ + collections.NewIndexReference(pk.K2(), pk.K1()), + }, nil + }, + ) + + return (*MultiPair[K1, K2, Value])(mi) +} + +// Iterate exposes the raw iterator API. +func (i *MultiPair[K1, K2, Value]) Iterate(ctx context.Context, ranger collections.Ranger[collections.Pair[K2, K1]]) (iter MultiPairIterator[K2, K1], err error) { + sIter, err := (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Iterate(ctx, ranger) + if err != nil { + return iter, err + } + return (MultiPairIterator[K2, K1])(sIter), nil +} + +// MatchExact will return an iterator containing only the primary keys starting with the provided second part of the multipart pair key. +func (i *MultiPair[K1, K2, Value]) MatchExact(ctx context.Context, key K2) (MultiPairIterator[K2, K1], error) { + return i.Iterate(ctx, collections.NewPrefixedPairRange[K2, K1](key)) +} + +// Reference implements collections.Index +func (i *MultiPair[K1, K2, Value]) Reference(ctx context.Context, pk collections.Pair[K1, K2], value Value, oldValue *Value) error { + return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Reference(ctx, pk, value, oldValue) +} + +// Unreference implements collections.Index +func (i *MultiPair[K1, K2, Value]) Unreference(ctx context.Context, pk collections.Pair[K1, K2], value Value) error { + return (*collections.GenericMultiIndex[K2, K1, collections.Pair[K1, K2], Value])(i).Unreference(ctx, pk, value) +} + +// MultiPairIterator is a helper type around a collections.KeySetIterator when used to work +// with MultiPair indexes iterations. +type MultiPairIterator[K2, K1 any] collections.KeySetIterator[collections.Pair[K2, K1]] + +// PrimaryKey returns the primary key from the index. The index is composed like a reverse +// pair key. So we just fetch the pair key from the index and return the reverse. +func (m MultiPairIterator[K2, K1]) PrimaryKey() (pair collections.Pair[K1, K2], err error) { + reversePair, err := (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Key() + if err != nil { + return pair, err + } + pair = collections.Join(reversePair.K2(), reversePair.K1()) + return pair, nil +} + +// PrimaryKeys returns all the primary keys contained in the iterator. +func (m MultiPairIterator[K2, K1]) PrimaryKeys() (pairs []collections.Pair[K1, K2], err error) { + defer m.Close() + for ; m.Valid(); m.Next() { + pair, err := m.PrimaryKey() + if err != nil { + return nil, err + } + pairs = append(pairs, pair) + } + return pairs, err +} + +func (m MultiPairIterator[K2, K1]) Next() { + (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Next() +} + +func (m MultiPairIterator[K2, K1]) Valid() bool { + return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Valid() +} + +func (m MultiPairIterator[K2, K1]) Close() error { + return (collections.KeySetIterator[collections.Pair[K2, K1]])(m).Close() +} diff --git a/collections/indexes/multi_pair_test.go b/collections/indexes/multi_pair_test.go new file mode 100644 index 000000000000..3e88f04f7bab --- /dev/null +++ b/collections/indexes/multi_pair_test.go @@ -0,0 +1,59 @@ +package indexes + +import ( + "cosmossdk.io/collections" + "github.com/stretchr/testify/require" + "testing" +) + +type Address = string +type Denom = string +type Amount = uint64 + +// our balance index, allows us to efficiently create an index between the key that maps +// balances which is a collections.Pair[Address, Denom] and the Denom. +type balanceIndex struct { + Denom *MultiPair[Address, Denom, Amount] +} + +func (b balanceIndex) IndexesList() []collections.Index[collections.Pair[Address, Denom], Amount] { + return []collections.Index[collections.Pair[Address, Denom], Amount]{b.Denom} +} + +func TestMultiPair(t *testing.T) { + sk, ctx := deps() + sb := collections.NewSchemaBuilder(sk) + // we create an indexed map that maps balances, which are saved as + // key: Pair[Address, Denom] + // value: Amount + keyCodec := collections.PairKeyCodec(collections.StringKey, collections.StringKey) + + indexedMap := collections.NewIndexedMap( + sb, + collections.NewPrefix("balances"), "balances", + keyCodec, + collections.Uint64Value, + balanceIndex{ + Denom: NewMultiPair[Amount](sb, collections.NewPrefix("denom_index"), "denom_index", keyCodec), + }, + ) + + err := indexedMap.Set(ctx, collections.Join("address1", "atom"), 100) + require.NoError(t, err) + + err = indexedMap.Set(ctx, collections.Join("address1", "osmo"), 200) + require.NoError(t, err) + + err = indexedMap.Set(ctx, collections.Join("address2", "osmo"), 300) + require.NoError(t, err) + + // assert if we iterate over osmo we find address1 and address2 + iter, err := indexedMap.Indexes.Denom.MatchExact(ctx, "osmo") + require.NoError(t, err) + defer iter.Close() + + pks, err := iter.PrimaryKeys() + require.NoError(t, err) + require.Equal(t, "address1", pks[0].K1()) + require.Equal(t, "address2", pks[1].K1()) +} diff --git a/collections/indexes/multi_test.go b/collections/indexes/multi_test.go new file mode 100644 index 000000000000..8c2ffa0ffbb5 --- /dev/null +++ b/collections/indexes/multi_test.go @@ -0,0 +1,61 @@ +package indexes + +import ( + "cosmossdk.io/collections" + "github.com/stretchr/testify/require" + "testing" +) + +func TestMultiIndex(t *testing.T) { + sk, ctx := deps() + schema := collections.NewSchemaBuilder(sk) + + mi := NewMulti(schema, collections.NewPrefix(1), "multi_index", collections.StringKey, collections.Uint64Key, func(_ uint64, value company) (string, error) { + return value.City, nil + }) + + // we crete two reference keys for primary key 1 and 2 associated with "milan" + require.NoError(t, mi.Reference(ctx, 1, company{City: "milan"}, nil)) + require.NoError(t, mi.Reference(ctx, 2, company{City: "milan"}, nil)) + + iter, err := mi.MatchExact(ctx, "milan") + require.NoError(t, err) + pks, err := iter.PrimaryKeys() + require.NoError(t, err) + require.Equal(t, []uint64{1, 2}, pks) + + // replace + require.NoError(t, mi.Reference(ctx, 1, company{City: "new york"}, &company{City: "milan"})) + + // assert after replace only company with id 2 is referenced by milan + iter, err = mi.MatchExact(ctx, "milan") + require.NoError(t, err) + pks, err = iter.PrimaryKeys() + require.NoError(t, err) + require.Equal(t, []uint64{2}, pks) + + // assert after replace company with id 1 is referenced by new york + iter, err = mi.MatchExact(ctx, "new york") + require.NoError(t, err) + pks, err = iter.PrimaryKeys() + require.NoError(t, err) + require.Equal(t, []uint64{1}, pks) + + // test iter methods + iter, err = mi.Iterate(ctx, nil) + require.NoError(t, err) + + fullKey, err := iter.FullKey() + require.NoError(t, err) + require.Equal(t, collections.Join("milan", uint64(2)), fullKey) + + pk, err := iter.PrimaryKey() + require.NoError(t, err) + require.Equal(t, uint64(2), pk) + + iter.Next() + require.True(t, iter.Valid()) + iter.Next() + require.False(t, iter.Valid()) + require.NoError(t, iter.Close()) +} diff --git a/collections/indexes/unique.go b/collections/indexes/unique.go new file mode 100644 index 000000000000..4385f2d11e4f --- /dev/null +++ b/collections/indexes/unique.go @@ -0,0 +1,92 @@ +package indexes + +import ( + "context" + "cosmossdk.io/collections" +) + +// Unique identifies an index that imposes uniqueness constraints on the reference key. +// It creates relationships between reference and primary key of the value. +type Unique[ReferenceKey, PrimaryKey, Value any] collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value] + +// NewUnique instantiates a new Unique index. +func NewUnique[ReferenceKey, PrimaryKey, Value any]( + schema *collections.SchemaBuilder, + prefix collections.Prefix, + name string, + refCodec collections.KeyCodec[ReferenceKey], + pkCodec collections.KeyCodec[PrimaryKey], + getRefKeyFunc func(pk PrimaryKey, v Value) (ReferenceKey, error), +) *Unique[ReferenceKey, PrimaryKey, Value] { + i := collections.NewGenericUniqueIndex(schema, prefix, name, refCodec, pkCodec, func(pk PrimaryKey, value Value) ([]collections.IndexReference[ReferenceKey, PrimaryKey], error) { + ref, err := getRefKeyFunc(pk, value) + if err != nil { + return nil, err + } + + return []collections.IndexReference[ReferenceKey, PrimaryKey]{ + collections.NewIndexReference(ref, pk), + }, nil + }) + + return (*Unique[ReferenceKey, PrimaryKey, Value])(i) +} + +func (i *Unique[ReferenceKey, PrimaryKey, Value]) Reference(ctx context.Context, pk PrimaryKey, newValue Value, oldValue *Value) error { + return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Reference(ctx, pk, newValue, oldValue) +} + +func (i *Unique[ReferenceKey, PrimaryKey, Value]) Unreference(ctx context.Context, pk PrimaryKey, value Value) error { + return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Unreference(ctx, pk, value) +} + +func (i *Unique[ReferenceKey, PrimaryKey, Value]) MatchExact(ctx context.Context, ref ReferenceKey) (PrimaryKey, error) { + return (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Get(ctx, ref) +} + +func (i *Unique[ReferenceKey, PrimaryKey, Value]) Iterate(ctx context.Context, ranger collections.Ranger[ReferenceKey]) (UniqueIterator[ReferenceKey, PrimaryKey], error) { + iter, err := (*collections.GenericUniqueIndex[ReferenceKey, PrimaryKey, PrimaryKey, Value])(i).Iterate(ctx, ranger) + return (UniqueIterator[ReferenceKey, PrimaryKey])(iter), err +} + +// UniqueIterator is an Iterator wrapper, that exposes only the functionality needed to work with Unique keys. +type UniqueIterator[ReferenceKey, PrimaryKey any] collections.Iterator[ReferenceKey, PrimaryKey] + +// PrimaryKey returns the iterator's current primary key. +func (i UniqueIterator[ReferenceKey, PrimaryKey]) PrimaryKey() (PrimaryKey, error) { + return (collections.Iterator[ReferenceKey, PrimaryKey])(i).Value() +} + +// PrimaryKeys fully consumes the iterator, and returns all the primary keys. +func (i UniqueIterator[ReferenceKey, PrimaryKey]) PrimaryKeys() ([]PrimaryKey, error) { + return (collections.Iterator[ReferenceKey, PrimaryKey])(i).Values() +} + +// FullKey returns the iterator's current full reference key as Pair[ReferenceKey, PrimaryKey]. +func (i UniqueIterator[ReferenceKey, PrimaryKey]) FullKey() (collections.Pair[ReferenceKey, PrimaryKey], error) { + kv, err := (collections.Iterator[ReferenceKey, PrimaryKey])(i).KeyValue() + return collections.Join(kv.Key, kv.Value), err +} + +func (i UniqueIterator[ReferenceKey, PrimaryKey]) FullKeys() ([]collections.Pair[ReferenceKey, PrimaryKey], error) { + kvs, err := (collections.Iterator[ReferenceKey, PrimaryKey])(i).KeyValues() + if err != nil { + return nil, err + } + pairKeys := make([]collections.Pair[ReferenceKey, PrimaryKey], len(kvs)) + for index := range kvs { + kv := kvs[index] + pairKeys[index] = collections.Join(kv.Key, kv.Value) + } + return pairKeys, nil +} + +func (i UniqueIterator[ReferenceKey, PrimaryKey]) Next() { + (collections.Iterator[ReferenceKey, PrimaryKey])(i).Next() +} +func (i UniqueIterator[ReferenceKey, PrimaryKey]) Valid() bool { + return (collections.Iterator[ReferenceKey, PrimaryKey])(i).Valid() +} +func (i UniqueIterator[ReferenceKey, PrimaryKey]) Close() error { + return (collections.Iterator[ReferenceKey, PrimaryKey])(i).Close() +} diff --git a/collections/indexes/unique_test.go b/collections/indexes/unique_test.go new file mode 100644 index 000000000000..7e9be931033d --- /dev/null +++ b/collections/indexes/unique_test.go @@ -0,0 +1,45 @@ +package indexes + +import ( + "cosmossdk.io/collections" + "github.com/stretchr/testify/require" + "testing" +) + +func TestUniqueIndex(t *testing.T) { + sk, ctx := deps() + schema := collections.NewSchemaBuilder(sk) + ui := NewUnique(schema, collections.NewPrefix("unique_index"), "unique_index", collections.Uint64Key, collections.Uint64Key, func(_ uint64, v company) (uint64, error) { + return v.Vat, nil + }) + + // map company with id 1 to vat 1_1 + err := ui.Reference(ctx, 1, company{Vat: 1_1}, nil) + require.NoError(t, err) + + // map company with id 2 to vat 2_2 + err = ui.Reference(ctx, 2, company{Vat: 2_2}, nil) + require.NoError(t, err) + + // mapping company 3 with vat 1_1 must yield to a ErrConflict + err = ui.Reference(ctx, 1, company{Vat: 1_1}, nil) + require.ErrorIs(t, err, collections.ErrConflict) + + // assert references are correct + id, err := ui.MatchExact(ctx, 1_1) + require.NoError(t, err) + require.Equal(t, uint64(1), id) + + id, err = ui.MatchExact(ctx, 2_2) + require.NoError(t, err) + require.Equal(t, uint64(2), id) + + // on reference updates, the new referencing key is created and the old is removed + err = ui.Reference(ctx, 1, company{Vat: 1_2}, &company{Vat: 1_1}) + require.NoError(t, err) + id, err = ui.MatchExact(ctx, 1_2) // assert a new reference is created + require.NoError(t, err) + require.Equal(t, uint64(1), id) + _, err = ui.MatchExact(ctx, 1_1) // assert old reference was removed + require.ErrorIs(t, err, collections.ErrNotFound) +} diff --git a/collections/pair.go b/collections/pair.go index 4bcd6637e828..5866875a5198 100644 --- a/collections/pair.go +++ b/collections/pair.go @@ -57,6 +57,10 @@ type pairKeyCodec[K1, K2 any] struct { keyCodec2 KeyCodec[K2] } +func (p pairKeyCodec[K1, K2]) KeyCodec1() KeyCodec[K1] { return p.keyCodec1 } + +func (p pairKeyCodec[K1, K2]) KeyCodec2() KeyCodec[K2] { return p.keyCodec2 } + func (p pairKeyCodec[K1, K2]) Encode(buffer []byte, pair Pair[K1, K2]) (int, error) { writtenTotal := 0 if pair.key1 != nil {