Skip to content

Commit

Permalink
Add tests for context management in DSL modules
Browse files Browse the repository at this point in the history
  • Loading branch information
xosmig committed Jun 23, 2022
1 parent de28cda commit ae18e0c
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 40 deletions.
266 changes: 228 additions & 38 deletions pkg/dsl/dslmodule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,30 @@ import (
"github.com/filecoin-project/mir/pkg/events"
"github.com/filecoin-project/mir/pkg/modules"
"github.com/filecoin-project/mir/pkg/pb/eventpb"
t "github.com/filecoin-project/mir/pkg/types"
"github.com/filecoin-project/mir/pkg/types"
"github.com/filecoin-project/mir/pkg/util/mathutil"
"github.com/filecoin-project/mir/pkg/util/sliceutil"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/wrapperspb"
"strconv"
"testing"
)

type moduleConfig struct {
Self t.ModuleID
Replies t.ModuleID
Reports t.ModuleID
type simpleModuleConfig struct {
Self types.ModuleID
Replies types.ModuleID
Reports types.ModuleID
}

func defaultModuleConfig() *moduleConfig {
return &moduleConfig{
func defaultSimpleModuleConfig() *simpleModuleConfig {
return &simpleModuleConfig{
Self: "testing",
Replies: "replies",
Reports: "reports",
}
}

func newTestingModule(mc *moduleConfig) modules.PassiveModule {
func newSimpleTestingModule(mc *simpleModuleConfig) modules.PassiveModule {
m := NewModule(mc.Self)

// state
Expand Down Expand Up @@ -100,7 +101,7 @@ func newTestingModule(mc *moduleConfig) modules.PassiveModule {
}

func TestDslModule_ApplyEvents(t *testing.T) {
mc := defaultModuleConfig()
mc := defaultSimpleModuleConfig()

tests := map[string]struct {
eventsIn *events.EventList
Expand All @@ -113,37 +114,42 @@ func TestDslModule_ApplyEvents(t *testing.T) {
err: nil,
},
"hello world": {
eventsIn: events.ListOf(TestingString(mc.Self, "hello")),
eventsOut: events.ListOf(TestingString(mc.Replies, "world"), TestingUint(mc.Replies, 42)),
eventsIn: events.ListOf(events.TestingString(mc.Self, "hello")),
eventsOut: events.ListOf(events.TestingString(mc.Replies, "world"), events.TestingUint(mc.Replies, 42)),
err: nil,
},
"test error": {
eventsIn: events.ListOf(TestingString(mc.Self, "good")),
eventsIn: events.ListOf(events.TestingString(mc.Self, "good")),
eventsOut: events.EmptyList(),
err: errors.New("bye"),
},
"test simple condition": {
eventsIn: events.ListOf(TestingString(mc.Self, "foo"), TestingString(mc.Self, "bar"),
TestingString(mc.Self, "baz"), TestingString(mc.Self, "quz")),
eventsOut: events.ListOf(TestingString(mc.Reports, "Collected at least 3 testing strings: [foo bar baz quz]")),
eventsIn: events.ListOf(
events.TestingString(mc.Self, "foo"), events.TestingString(mc.Self, "bar"),
events.TestingString(mc.Self, "baz"), events.TestingString(mc.Self, "quz")),
eventsOut: events.ListOf(
events.TestingString(mc.Reports, "Collected at least 3 testing strings: [foo bar baz quz]")),
},
"test multiple handlers for one event and a loop condition": {
eventsIn: events.ListOf(TestingUint(mc.Self, 0), TestingUint(mc.Self, 17), TestingUint(mc.Self, 105),
TestingUint(mc.Self, 182), TestingUint(mc.Self, 42), TestingUint(mc.Self, 222),
TestingUint(mc.Self, 14)),
eventsIn: events.ListOf(
events.TestingUint(mc.Self, 0), events.TestingUint(mc.Self, 17), events.TestingUint(mc.Self, 105),
events.TestingUint(mc.Self, 182), events.TestingUint(mc.Self, 42), events.TestingUint(mc.Self, 222),
events.TestingUint(mc.Self, 14)),
// if the number is below 100, the module will reply with a string representation of the number.
// the module will also add up all received values and will emit reports 100, 200, and so on if these
// thresholds are passed at the end of the batch. In this example, the total sum is 582.
eventsOut: events.ListOf(TestingString(mc.Replies, "0"), TestingString(mc.Replies, "17"),
TestingString(mc.Replies, "42"), TestingString(mc.Replies, "14"), TestingUint(mc.Reports, 100),
TestingUint(mc.Reports, 200), TestingUint(mc.Reports, 300), TestingUint(mc.Reports, 400),
TestingUint(mc.Reports, 500)),
eventsOut: events.ListOf(
events.TestingString(mc.Replies, "0"), events.TestingString(mc.Replies, "17"),
events.TestingString(mc.Replies, "42"), events.TestingString(mc.Replies, "14"),
events.TestingUint(mc.Reports, 100), events.TestingUint(mc.Reports, 200),
events.TestingUint(mc.Reports, 300), events.TestingUint(mc.Reports, 400),
events.TestingUint(mc.Reports, 500)),
},
}

for testName, tc := range tests {
t.Run(testName, func(t *testing.T) {
m := newTestingModule(mc)
m := newSimpleTestingModule(mc)
eventsOutList, err := m.ApplyEvents(tc.eventsIn)

if tc.err != nil {
Expand Down Expand Up @@ -177,34 +183,218 @@ func TestDslModule_ApplyEvents(t *testing.T) {
}
}

// protobuf wrappers (similar to the ones in pkg/events/events.pb)
type contextTestingModuleModuleConfig struct {
Self types.ModuleID
Crypto types.ModuleID
Hasher types.ModuleID
Timer types.ModuleID
Signed types.ModuleID
Hashed types.ModuleID
Verified types.ModuleID
}

func defaultContextTestingModuleConfig() *contextTestingModuleModuleConfig {
return &contextTestingModuleModuleConfig{
Self: "testing",
Crypto: "crypto",
Hasher: "hasher",
Timer: "timer",
Signed: "signed",
Hashed: "hashed",
Verified: "verified",
}
}

type testingStringContext struct {
s string
}

func newContextTestingModule(mc *contextTestingModuleModuleConfig) Module {
m := NewModule(mc.Self)

UponTestingString(m, func(s string) error {
SignRequest(m, mc.Crypto, [][]byte{[]byte(s)}, &testingStringContext{s})
HashOneMessage(m, mc.Hasher, [][]byte{[]byte(s)}, &testingStringContext{s})
return nil
})

UponSignResult(m, func(signature []byte, context *testingStringContext) error {
EmitTestingString(m, mc.Signed, fmt.Sprintf("%s: %s", context.s, string(signature)))
return nil
})

UponHashResult(m, func(hashes [][]byte, context *testingStringContext) error {
if len(hashes) != 1 {
return fmt.Errorf("unexpected number of hashes: %v", hashes)
}
EmitTestingString(m, mc.Hashed, fmt.Sprintf("%s: %s", context.s, string(hashes[0])))
return nil
})

UponTestingUint(m, func(u uint64) error {
if u < 10 {
msg := [][]byte{[]byte("uint"), []byte(strconv.FormatUint(u, 10))}

var signatures [][]byte
var nodeIDs []types.NodeID
for i := uint64(0); i < u; i++ {
signatures = append(signatures, []byte(strconv.FormatUint(i, 10)))
nodeIDs = append(nodeIDs, types.NodeID(strconv.FormatUint(i, 10)))
}

// NB: avoid using primitive types as the context in the actual implementation, prefer named structs,
// remember that the context type is used to match requests with responses.
VerifyNodeSigs(m, mc.Crypto, sliceutil.Repeat(msg, u), signatures, nodeIDs, &u)
}
return nil
})

UponOneNodeSigVerified(m, func(nodeID types.NodeID, err error, context *uint64) error {
if err == nil {
EmitTestingString(m, mc.Verified, fmt.Sprintf("%v: %v verified", *context, nodeID))
}
return nil
})

UponNodeSigsVerified(m, func(nodeIDs []types.NodeID, errs []error, allOK bool, context *uint64) error {
if allOK {
EmitTestingUint(m, mc.Verified, *context)
}
return nil
})

return m
}

func TestDslModule_ContextRecoveryAndCleanup(t *testing.T) {
testCases := map[string]func(mc *contextTestingModuleModuleConfig, m Module){
"empty": func(mc *contextTestingModuleModuleConfig, m Module) {},

func TestingString(dest t.ModuleID, s string) *eventpb.Event {
return &eventpb.Event{
DestModule: dest.Pb(),
Type: &eventpb.Event_TestingString{
TestingString: wrapperspb.String(s),
"request_response": func(mc *contextTestingModuleModuleConfig, m Module) {
eventsOut, err := m.ApplyEvents(events.ListOf(events.TestingString(mc.Self, "hello")))
assert.Nil(t, err)
assert.Equal(t, 2, eventsOut.Len())

iter := eventsOut.Iterator()
signOrigin := iter.Next().Type.(*eventpb.Event_SignRequest).SignRequest.Origin
hashOrigin := iter.Next().Type.(*eventpb.Event_HashRequest).HashRequest.Origin

eventsOut, err = m.ApplyEvents(events.ListOf(events.SignResult(mc.Self, []byte("world"), signOrigin)))
assert.Nil(t, err)
assert.Equal(t, []*eventpb.Event{events.TestingString(mc.Signed, "hello: world")}, eventsOut.Slice())

eventsOut, err = m.ApplyEvents(events.ListOf(events.HashResult(mc.Self, [][]byte{[]byte("world")}, hashOrigin)))
assert.Nil(t, err)
assert.Equal(t, []*eventpb.Event{events.TestingString(mc.Hashed, "hello: world")}, eventsOut.Slice())
},

"response_without_request": func(mc *contextTestingModuleModuleConfig, m Module) {
assert.Panics(t, func() {
// Context with id 42 doesn't exist. The module should panic.
_, _ = m.ApplyEvents(events.ListOf(
events.SignResult(mc.Self, []byte{}, DslSignOrigin(mc.Self, ContextID(42)))))
})
},

"check_context_is_disposed": func(mc *contextTestingModuleModuleConfig, m Module) {
eventsOut, err := m.ApplyEvents(events.ListOf(events.TestingString(mc.Self, "hello")))
assert.Nil(t, err)
assert.Equal(t, 2, eventsOut.Len())

iter := eventsOut.Iterator()
signOrigin := iter.Next().Type.(*eventpb.Event_SignRequest).SignRequest.Origin
_ = iter.Next().Type.(*eventpb.Event_HashRequest).HashRequest.Origin

eventsOut, err = m.ApplyEvents(events.ListOf(events.SignResult(mc.Self, []byte("world"), signOrigin)))
assert.Nil(t, err)
assert.Equal(t, []*eventpb.Event{events.TestingString(mc.Signed, "hello: world")}, eventsOut.Slice())

assert.Panics(t, func() {
// This reply is sent for the second time.
//The context should already be disposed of and the module should panic.
_, _ = m.ApplyEvents(events.ListOf(events.SignResult(mc.Self, []byte("world"), signOrigin)))
})
},

"check_multiple_handlers_for_response": func(mc *contextTestingModuleModuleConfig, m Module) {
eventsOut, err := m.ApplyEvents(events.ListOf(events.TestingUint(mc.Self, 8)))
assert.Nil(t, err)
assert.Equal(t, 1, eventsOut.Len())

iter := eventsOut.Iterator()
sigVerEvent := iter.Next().Type.(*eventpb.Event_VerifyNodeSigs).VerifyNodeSigs
sigVerNodes := sigVerEvent.NodeIds
assert.Equal(t, 8, len(sigVerNodes))
sigVerOrigin := sigVerEvent.Origin

// send some undelated events to make sure the context is preserved and does not get overwritten
_, err = m.ApplyEvents(events.ListOf(events.TestingString(mc.Self, "hello")))
assert.Nil(t, err)
_, err = m.ApplyEvents(events.ListOf(events.TestingUint(mc.Self, 3)))
assert.Nil(t, err)
_, err = m.ApplyEvents(events.ListOf(events.TestingUint(mc.Self, 16), events.TestingString(mc.Self, "foo")))
assert.Nil(t, err)

// construct a response for the signature verification request.
sigsVerifiedEvent := events.NodeSigsVerified(
/*destModule*/ mc.Self,
/*valid*/ sliceutil.Repeat(true, 8),
/*errors*/ sliceutil.Repeat("", 8),
/*nodeIDs*/ types.NodeIDSlice(sigVerNodes),
/*origin*/ sigVerOrigin,
/*allOk*/ true,
)

eventsOut, err = m.ApplyEvents(events.ListOf(sigsVerifiedEvent))
assert.Nil(t, err)

var expectedResponse []*eventpb.Event
for i := 0; i < 8; i++ {
expectedResponse = append(expectedResponse, events.TestingString(mc.Verified, fmt.Sprintf("8: %v verified", i)))
}
expectedResponse = append(expectedResponse, events.TestingUint(mc.Verified, 8))

assert.Equal(t, expectedResponse, eventsOut.Slice())

assert.Panics(t, func() {
// This reply is sent for the second time.
//The context should already be disposed of and the module should panic.
_, _ = m.ApplyEvents(events.ListOf(sigsVerifiedEvent))
})
},
}

for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
mc := defaultContextTestingModuleConfig()
m := newContextTestingModule(mc)
tc(mc, m)
})
}

}

func TestingUint(dest t.ModuleID, u uint64) *eventpb.Event {
return &eventpb.Event{
DestModule: dest.Pb(),
Type: &eventpb.Event_TestingUint{
TestingUint: wrapperspb.UInt64(u),
// event wrappers (similar to the ones in pkg/events/events.go)

func DslSignOrigin(module types.ModuleID, contextID ContextID) *eventpb.SignOrigin {
return &eventpb.SignOrigin{
Module: module.Pb(),
Type: &eventpb.SignOrigin_Dsl{
Dsl: &eventpb.DslOrigin{
ContextID: contextID.Pb(),
},
},
}
}

// dsl wrappers (similar to the ones in pkg/dsl/events.go)

func EmitTestingString(m Module, dest t.ModuleID, s string) {
EmitEvent(m, TestingString(dest, s))
func EmitTestingString(m Module, dest types.ModuleID, s string) {
EmitEvent(m, events.TestingString(dest, s))
}

func EmitTestingUint(m Module, dest t.ModuleID, u uint64) {
EmitEvent(m, TestingUint(dest, u))
func EmitTestingUint(m Module, dest types.ModuleID, u uint64) {
EmitEvent(m, events.TestingUint(dest, u))
}

func UponTestingString(m Module, handler func(s string) error) {
Expand Down
6 changes: 4 additions & 2 deletions pkg/util/sliceutil/sliceutil.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package sliceutil

import "golang.org/x/exp/constraints"

// Repeat returns a slice with value repeated n times.
func Repeat[T any](value T, n int) []T {
func Repeat[T any, N constraints.Integer](value T, n N) []T {
arr := make([]T, n)
for i := 0; i < n; i++ {
for i := N(0); i < n; i++ {
arr[i] = value
}
return arr
Expand Down

0 comments on commit ae18e0c

Please sign in to comment.