From b3fced38daabe4b53c4313e358022a893bde5da9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 13 Dec 2024 11:27:09 +0000 Subject: [PATCH] Performance and bug-fix backports from Harmony This contains various backports: * Reduced allocations and performance improvements in state resolution * Reduced allocations and performance improvements in JSON handling * Event auth fixes, including correct error surfacing --- authstate.go | 14 +++-- backfill.go | 3 +- eventV1.go | 3 +- eventauth.go | 8 ++- eventauth_test.go | 14 ++--- handleinvite_test.go | 8 ++- handlejoin.go | 7 ++- handlejoin_test.go | 1 + handleleave.go | 7 ++- json.go | 63 +++++++------------ performinvite.go | 4 +- performjoin_test.go | 3 +- stateresolutionv2.go | 129 ++++++++++++++++++++++---------------- stateresolutionv2_test.go | 11 ++-- stateresolutionv2heaps.go | 81 +++++++++--------------- 15 files changed, 180 insertions(+), 176 deletions(-) diff --git a/authstate.go b/authstate.go index 046ed804..2344442d 100644 --- a/authstate.go +++ b/authstate.go @@ -165,7 +165,10 @@ func checkAllowedByAuthEvents( event PDU, eventsByID map[string]PDU, missingAuth EventProvider, userIDForSender spec.UserIDForSender, ) error { - authEvents := NewAuthEvents(nil) + authEvents, err := NewAuthEvents(nil) + if err != nil { + return err + } for _, ae := range event.AuthEventIDs() { retryEvent: @@ -214,7 +217,7 @@ func checkAllowedByAuthEvents( // If we made it this far then we've successfully got as many of the auth events as // as described by AuthEventIDs(). Check if they allow the event. - if err := Allowed(event, &authEvents, userIDForSender); err != nil { + if err := Allowed(event, authEvents, userIDForSender); err != nil { return fmt.Errorf( "gomatrixserverlib: event with ID %q is not allowed by its auth_events: %s", event.EventID(), err.Error(), @@ -335,7 +338,10 @@ func CheckSendJoinResponse( } eventsByID := map[string]PDU{} - authEventProvider := NewAuthEvents(nil) + authEventProvider, err := NewAuthEvents(nil) + if err != nil { + return nil, err + } // Since checkAllowedByAuthEvents needs to be able to look up any of the // auth events by ID only, we will build a map which contains references @@ -369,7 +375,7 @@ func CheckSendJoinResponse( } // Now check that the join event is valid against the supplied state. - if err := Allowed(joinEvent, &authEventProvider, userIDForSender); err != nil { + if err := Allowed(joinEvent, authEventProvider, userIDForSender); err != nil { return nil, fmt.Errorf( "gomatrixserverlib: event with ID %q is not allowed by the current room state: %w", joinEvent.EventID(), err, diff --git a/backfill.go b/backfill.go index 23c5f87e..d2d2a02a 100644 --- a/backfill.go +++ b/backfill.go @@ -99,7 +99,8 @@ func RequestBackfill(ctx context.Context, origin spec.ServerName, b BackfillRequ } } - return result, lastErr + // Since we pulled in results from multiple servers we need to sort again... + return ReverseTopologicalOrdering(result, TopologicalOrderByPrevEvents), lastErr } /* diff --git a/eventV1.go b/eventV1.go index c01fd851..18f2cef0 100644 --- a/eventV1.go +++ b/eventV1.go @@ -215,8 +215,7 @@ func (e *eventV1) SetUnsignedField(path string, value interface{}) error { eventJSON = CanonicalJSONAssumeValid(eventJSON) res := gjson.GetBytes(eventJSON, "unsigned") - unsigned := RawJSONFromResult(res, eventJSON) - e.eventFields.Unsigned = unsigned + e.eventFields.Unsigned = []byte(res.Raw) e.eventJSON = eventJSON diff --git a/eventauth.go b/eventauth.go index 39a41d7d..8397fffb 100644 --- a/eventauth.go +++ b/eventauth.go @@ -301,15 +301,17 @@ func (a *AuthEvents) Clear() { // NewAuthEvents returns an AuthEventProvider backed by the given events. New events can be added by // calling AddEvent(). -func NewAuthEvents(events []PDU) AuthEvents { +func NewAuthEvents(events []PDU) (*AuthEvents, error) { a := AuthEvents{ events: make(map[StateKeyTuple]PDU, len(events)), roomIDs: make(map[string]struct{}), } for _, e := range events { - a.AddEvent(e) // nolint: errcheck + if err := a.AddEvent(e); err != nil { + return nil, err + } } - return a + return &a, nil } // A NotAllowed error is returned if an event does not pass the auth checks. diff --git a/eventauth_test.go b/eventauth_test.go index f9040aee..428f6a69 100644 --- a/eventauth_test.go +++ b/eventauth_test.go @@ -104,7 +104,7 @@ func TestStateNeededForMessage(t *testing.T) { // Message events need the create event, the sender and the power_levels. testStateNeededForAuth(t, `[{ "type": "m.room.message", - "sender": "@u1:a", + "sender": "@u1:a", "room_id": "!r1:a" }]`, &ProtoEvent{ Type: "m.room.message", @@ -139,7 +139,7 @@ func TestStateNeededForJoin(t *testing.T) { "type": "m.room.member", "state_key": "@u1:a", "sender": "@u1:a", - "content": {"membership": "join"}, + "content": {"membership": "join"}, "room_id": "!r1:a" }]`, &b, StateNeeded{ Create: true, @@ -163,7 +163,7 @@ func TestStateNeededForInvite(t *testing.T) { "type": "m.room.member", "state_key": "@u2:b", "sender": "@u1:a", - "content": {"membership": "invite"}, + "content": {"membership": "invite"}, "room_id": "!r1:a" }]`, &b, StateNeeded{ Create: true, @@ -199,7 +199,7 @@ func TestStateNeededForInvite3PID(t *testing.T) { "token": "my_token" } } - }, + }, "room_id": "!r1:a" }]`, &b, StateNeeded{ Create: true, @@ -1035,7 +1035,7 @@ func TestAuthEvents(t *testing.T) { if err != nil { t.Fatalf("TestAuthEvents: failed to create power_levels event: %s", err) } - a := NewAuthEvents([]PDU{power}) + a, _ := NewAuthEvents([]PDU{power}) var e PDU if e, err = a.PowerLevels(); err != nil || e != power { t.Errorf("TestAuthEvents: failed to get same power_levels event") @@ -1685,7 +1685,7 @@ func TestMembershipBanned(t *testing.T) { "state_key": "@u3:a", "event_id": "$e4:a", "content": {"membership": "ban"} - }, + }, { "type": "m.room.member", "sender": "@u2:a", @@ -1693,7 +1693,7 @@ func TestMembershipBanned(t *testing.T) { "state_key": "@u3:a", "event_id": "$e4:a", "content": {"membership": "ban"} - }, + }, { "type": "m.room.member", "sender": "@u2:a", diff --git a/handleinvite_test.go b/handleinvite_test.go index aab1935b..0255f020 100644 --- a/handleinvite_test.go +++ b/handleinvite_test.go @@ -38,9 +38,11 @@ func (r *TestStateQuerier) GetAuthEvents(ctx context.Context, event PDU) (AuthEv return nil, fmt.Errorf("failed getting auth provider") } - eventProvider := AuthEvents{} + eventProvider, _ := NewAuthEvents(nil) if r.createEvent != nil { - eventProvider = NewAuthEvents([]PDU{r.createEvent}) + if err := eventProvider.AddEvent(r.createEvent); err != nil { + return nil, err + } if r.inviterMemberEvent != nil { err := eventProvider.AddEvent(r.inviterMemberEvent) if err != nil { @@ -48,7 +50,7 @@ func (r *TestStateQuerier) GetAuthEvents(ctx context.Context, event PDU) (AuthEv } } } - return &eventProvider, nil + return eventProvider, nil } func (r *TestStateQuerier) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []StateKeyTuple) ([]PDU, error) { diff --git a/handlejoin.go b/handlejoin.go index 56234995..4d8de783 100644 --- a/handlejoin.go +++ b/handlejoin.go @@ -116,8 +116,11 @@ func HandleMakeJoin(input HandleMakeJoinInput) (*HandleMakeJoinResponse, error) return nil, spec.InternalServerError{Err: fmt.Sprintf("expected join event from template builder. got: %s", event.Type())} } - provider := NewAuthEvents(state) - if err = Allowed(event, &provider, input.UserIDQuerier); err != nil { + provider, err := NewAuthEvents(state) + if err != nil { + return nil, spec.Forbidden(err.Error()) + } + if err = Allowed(event, provider, input.UserIDQuerier); err != nil { return nil, spec.Forbidden(err.Error()) } diff --git a/handlejoin_test.go b/handlejoin_test.go index 5c61a7e2..861e99d8 100644 --- a/handlejoin_test.go +++ b/handlejoin_test.go @@ -632,6 +632,7 @@ func TestHandleMakeJoinNilContext(t *testing.T) { }) } +//nolint:unparam func createMemberEventBuilder(roomVersion RoomVersion, sender string, roomID string, stateKey *string, content spec.RawJSON) *EventBuilder { return MustGetRoomVersion(roomVersion).NewEventBuilderFromProtoEvent(&ProtoEvent{ SenderID: sender, diff --git a/handleleave.go b/handleleave.go index bcd73fa6..3e95329f 100644 --- a/handleleave.go +++ b/handleleave.go @@ -81,8 +81,11 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro return nil, spec.InternalServerError{Err: fmt.Sprintf("expected leave event from template builder. got: %s", event.Type())} } - provider := NewAuthEvents(stateEvents) - if err := Allowed(event, &provider, input.UserIDQuerier); err != nil { + provider, err := NewAuthEvents(stateEvents) + if err != nil { + return nil, spec.Forbidden(err.Error()) + } + if err = Allowed(event, provider, input.UserIDQuerier); err != nil { return nil, spec.Forbidden(err.Error()) } diff --git a/json.go b/json.go index 86f22ef2..4e27ec35 100644 --- a/json.go +++ b/json.go @@ -18,7 +18,7 @@ package gomatrixserverlib import ( "encoding/binary" "errors" - "sort" + "slices" "strings" "unicode/utf16" "unicode/utf8" @@ -159,40 +159,33 @@ func CanonicalJSONAssumeValid(input []byte) []byte { // by codepoint. The input must be valid JSON. func SortJSON(input, output []byte) []byte { result := gjson.ParseBytes(input) - - RawJSON := RawJSONFromResult(result, input) - return sortJSONValue(result, RawJSON, output) + return sortJSONValue(result, output) } // sortJSONValue takes a gjson.Result and sorts it. inputJSON must be the // raw JSON bytes that gjson.Result points to. -func sortJSONValue(input gjson.Result, inputJSON, output []byte) []byte { +func sortJSONValue(input gjson.Result, output []byte) []byte { if input.IsArray() { - return sortJSONArray(input, inputJSON, output) + return sortJSONArray(input, output) } - if input.IsObject() { - return sortJSONObject(input, inputJSON, output) + return sortJSONObject(input, output) } - // If its neither an object nor an array then there is no sub structure // to sort, so just append the raw bytes. - return append(output, inputJSON...) + return append(output, input.Raw...) } // sortJSONArray takes a gjson.Result and sorts it, assuming its an array. // inputJSON must be the raw JSON bytes that gjson.Result points to. -func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte { +func sortJSONArray(input gjson.Result, output []byte) []byte { sep := byte('[') // Iterate over each value in the array and sort it. input.ForEach(func(_, value gjson.Result) bool { output = append(output, sep) sep = ',' - - RawJSON := RawJSONFromResult(value, inputJSON) - output = sortJSONValue(value, RawJSON, output) - + output = sortJSONValue(value, output) return true // keep iterating }) @@ -209,29 +202,30 @@ func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte { // sortJSONObject takes a gjson.Result and sorts it, assuming its an object. // inputJSON must be the raw JSON bytes that gjson.Result points to. -func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte { +func sortJSONObject(input gjson.Result, output []byte) []byte { type entry struct { - key string // The parsed key string - rawKey []byte // The raw, unparsed key JSON string - value gjson.Result + key string // The parsed key string + value gjson.Result } - var entries []entry + // Try to stay on the stack here if we can. + var _entries [128]entry + entries := _entries[:0] // Iterate over each key/value pair and add it to a slice // that we can sort input.ForEach(func(key, value gjson.Result) bool { entries = append(entries, entry{ - key: key.String(), - rawKey: RawJSONFromResult(key, inputJSON), - value: value, + key: key.String(), + value: value, }) return true // keep iterating }) - // Sort the slice based on the *parsed* key - sort.Slice(entries, func(a, b int) bool { - return entries[a].key < entries[b].key + // Using slices.SortFunc here instead of sort.Slice avoids + // heap escapes due to reflection. + slices.SortFunc(entries, func(a, b entry) int { + return strings.Compare(a.key, b.key) }) sep := byte('{') @@ -241,12 +235,10 @@ func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte { sep = ',' // Append the raw unparsed JSON key, *not* the parsed key - output = append(output, entry.rawKey...) - output = append(output, ':') - - RawJSON := RawJSONFromResult(entry.value, inputJSON) - - output = sortJSONValue(entry.value, RawJSON, output) + output = append(output, '"') + output = append(output, entry.key...) + output = append(output, '"', ':') + output = sortJSONValue(entry.value, output) } if sep == '{' { // If sep is still '{' then the object was empty and we never wrote the @@ -375,10 +367,3 @@ func readHexDigits(input []byte) rune { hex |= hex >> 8 return rune(hex & 0xFFFF) } - -// RawJSONFromResult extracts the raw JSON bytes pointed to by result. -// input must be the json bytes that were used to generate result -// TODO: Why do we do this? -func RawJSONFromResult(result gjson.Result, _ []byte) []byte { - return []byte(result.Raw) -} diff --git a/performinvite.go b/performinvite.go index 7f591a29..a1830d8e 100644 --- a/performinvite.go +++ b/performinvite.go @@ -123,7 +123,7 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede input.EventTemplate.Depth = latestEvents.Depth - authEvents := NewAuthEvents(nil) + authEvents, _ := NewAuthEvents(nil) for _, event := range latestEvents.StateEvents { err := authEvents.AddEvent(event) @@ -132,7 +132,7 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede } } - refs, err := stateNeeded.AuthEventReferences(&authEvents) + refs, err := stateNeeded.AuthEventReferences(authEvents) if err != nil { return nil, fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err) } diff --git a/performjoin_test.go b/performjoin_test.go index 7d0e2e3c..7c5395c6 100644 --- a/performjoin_test.go +++ b/performjoin_test.go @@ -13,7 +13,6 @@ import ( "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" - ed255192 "golang.org/x/crypto/ed25519" ) type TestMakeJoinResponse struct { @@ -388,7 +387,7 @@ func TestPerformJoinPseudoID(t *testing.T) { return res, nil } - idCreator := func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed255192.PrivateKey, error) { + idCreator := func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) { return spec.SenderIDFromPseudoIDKey(userPriv), userPriv, nil } diff --git a/stateresolutionv2.go b/stateresolutionv2.go index 8ff346cd..36af1b41 100644 --- a/stateresolutionv2.go +++ b/stateresolutionv2.go @@ -15,10 +15,9 @@ package gomatrixserverlib import ( - "container/heap" "encoding/json" "fmt" - "sort" + "slices" "github.com/matrix-org/gomatrixserverlib/spec" ) @@ -34,7 +33,7 @@ const ( type stateResolverV2 struct { allower *allowerContext // Used to auth and apply events - authProvider AuthEvents // Used in the allower + authProvider *AuthEvents // Used in the allower authEventMap map[string]PDU // Map of all provided auth events conflictedEventMap map[string]PDU // Map of all provided conflicted events powerLevelContents map[string]*PowerLevelContent // A cache of all power level contents @@ -47,6 +46,7 @@ type stateResolverV2 struct { resolvedOthers map[StateKeyTuple]PDU // Resolved other events result []PDU // Final list of resolved events isRejectedFn IsRejected // Check if the given eventID is rejected + isRejectedCache map[string]bool // Events known to be or not be rejected } // IsRejected should return if the given eventID is rejected or not. @@ -64,9 +64,10 @@ func ResolveStateConflictsV2( // Prepare the state resolver. conflictedControlEvents := make([]PDU, 0, len(conflicted)) conflictedOthers := make([]PDU, 0, len(conflicted)) + authProvider, _ := NewAuthEvents(nil) r := stateResolverV2{ authEventMap: eventMapFromEvents(authEvents), - authProvider: NewAuthEvents(nil), + authProvider: authProvider, conflictedEventMap: eventMapFromEvents(conflicted), powerLevelContents: make(map[string]*PowerLevelContent), powerLevelMainlinePos: make(map[string]int), @@ -75,6 +76,7 @@ func ResolveStateConflictsV2( resolvedOthers: make(map[StateKeyTuple]PDU, len(conflicted)), result: make([]PDU, 0, len(conflicted)+len(unconflicted)), isRejectedFn: isRejectedFn, + isRejectedCache: make(map[string]bool), } var roomID *spec.RoomID if len(conflicted) > 0 { @@ -95,7 +97,7 @@ func ResolveStateConflictsV2( return r.result } - r.allower = newAllowerContext(&r.authProvider, userIDForSender, *roomID) + r.allower = newAllowerContext(r.authProvider, userIDForSender, *roomID) // This is a map to help us determine if an event already belongs to the // unconflicted set. If it does then we shouldn't add it back into the @@ -164,13 +166,13 @@ func ResolveStateConflictsV2( // state. We will then keep the successfully authed unconflicted events so that // they can be reapplied later. unconflicted = r.reverseTopologicalOrdering(unconflicted, TopologicalOrderByAuthEvents) - r.applyEvents(unconflicted) + r.applyEvents(unconflicted...) // Then order the conflicted power level events topologically and then also // auth those too. The successfully authed events will be layered on top of // the partial state. conflictedControlEvents = r.reverseTopologicalOrdering(conflictedControlEvents, TopologicalOrderByAuthEvents) - r.authAndApplyEvents(conflictedControlEvents) + r.authAndApplyEvents(conflictedControlEvents...) // Then generate the mainline of power level events, order the remaining state // events based on the mainline ordering and auth those too. The successfully @@ -179,13 +181,13 @@ func ResolveStateConflictsV2( r.powerLevelMainlinePos[event.EventID()] = pos } conflictedOthers = r.mainlineOrdering(conflictedOthers) - r.authAndApplyEvents(conflictedOthers) + r.authAndApplyEvents(conflictedOthers...) // Finally we will reapply the original set of unconflicted events onto the // partial state, just in case any of these were overwritten by pulling in // auth events in the previous two steps, and that gives us our final resolved // state. - r.applyEvents(unconflicted) + r.applyEvents(unconflicted...) // Now that we have our final state, populate the result array with the // resolved state and return it. @@ -439,49 +441,70 @@ func (r *stateResolverV2) getFirstPowerLevelMainlineEvent(event PDU) ( // also apply them on top of the partial state. If they fail auth checks then // the event is ignored and dropped. Returns two lists - the first contains the // accepted (authed) events and the second contains the rejected events. -func (r *stateResolverV2) authAndApplyEvents(events []PDU) { +func (r *stateResolverV2) authAndApplyEvents(events ...PDU) { + addFromAuthEventsIfNotRejected := func(event PDU, eventType, stateKey string) { + for _, authEventID := range event.AuthEventIDs() { + if _, ok := r.isRejectedCache[authEventID]; !ok { + r.isRejectedCache[authEventID] = r.isRejectedFn(authEventID) + } + if rejected := r.isRejectedCache[authEventID]; rejected { + continue + } + authEv, ok := r.authEventMap[authEventID] + if !ok { + continue + } + if authEv.Type() != eventType || !authEv.StateKeyEquals(stateKey) { + continue + } + _ = r.authProvider.AddEvent(event) + } + } + for _, event := range events { r.authProvider.Clear() // Now layer on the partial state events that we do know. This should // mean that we make forward progress. needed := StateNeededForAuth([]PDU{event}) - if event := r.resolvedCreate; needed.Create && event != nil { - _ = r.authProvider.AddEvent(event) + if resolved := r.resolvedCreate; needed.Create { + if resolved != nil { + _ = r.authProvider.AddEvent(resolved) + } else { + addFromAuthEventsIfNotRejected(event, spec.MRoomCreate, "") + } } - if event := r.resolvedJoinRules; needed.JoinRules && event != nil { - _ = r.authProvider.AddEvent(event) + if resolved := r.resolvedJoinRules; needed.JoinRules { + if resolved != nil { + _ = r.authProvider.AddEvent(resolved) + } else { + addFromAuthEventsIfNotRejected(event, spec.MRoomJoinRules, "") + } } - if event := r.resolvedPowerLevels; needed.PowerLevels && event != nil { - _ = r.authProvider.AddEvent(event) + if resolved := r.resolvedPowerLevels; needed.PowerLevels { + if resolved != nil { + _ = r.authProvider.AddEvent(resolved) + } else { + addFromAuthEventsIfNotRejected(event, spec.MRoomPowerLevels, "") + } } for _, needed := range needed.Member { - if membershipEvent := r.resolvedMembers[spec.SenderID(needed)]; membershipEvent != nil { - _ = r.authProvider.AddEvent(membershipEvent) + if resolved := r.resolvedMembers[spec.SenderID(needed)]; resolved != nil { + _ = r.authProvider.AddEvent(resolved) } else { - for _, authEventID := range event.AuthEventIDs() { - authEv, ok := r.authEventMap[authEventID] - if !ok { - continue - } - if authEv.Type() == spec.MRoomMember && authEv.StateKeyEquals(needed) { - // Don't use rejected events for auth - if r.isRejectedFn(authEventID) { - continue - } - _ = r.authProvider.AddEvent(authEv) - } - } + addFromAuthEventsIfNotRejected(event, spec.MRoomMember, needed) } } for _, needed := range needed.ThirdPartyInvite { - if event := r.resolvedThirdPartyInvites[needed]; event != nil { - _ = r.authProvider.AddEvent(event) + if resolved := r.resolvedThirdPartyInvites[needed]; resolved != nil { + _ = r.authProvider.AddEvent(resolved) + } else { + addFromAuthEventsIfNotRejected(event, spec.MRoomThirdPartyInvite, needed) } } // Check if the event is allowed based on the current partial state. - r.allower.update(&r.authProvider) + r.allower.update(r.authProvider) if err := r.allower.allowed(event); err != nil { // The event was not allowed by the partial state and/or relevant // auth events from the event, so skip it. @@ -489,12 +512,12 @@ func (r *stateResolverV2) authAndApplyEvents(events []PDU) { } // Apply the newly authed event to the partial state. We need to do this // here so that the next loop will have partial state to auth against. - r.applyEvents([]PDU{event}) + r.applyEvents(event) } } // applyEvents applies the events on top of the partial state. -func (r *stateResolverV2) applyEvents(events []PDU) { +func (r *stateResolverV2) applyEvents(events ...PDU) { for _, event := range events { if st, sk := event.Type(), event.StateKey(); sk == nil { continue @@ -602,7 +625,7 @@ func (r *stateResolverV2) reverseTopologicalOrdering(events []PDU, order Topolog func (r *stateResolverV2) mainlineOrdering(events []PDU) []PDU { block := r.wrapOtherEventsForSort(events) result := make([]PDU, 0, len(block)) - sort.Sort(stateResV2ConflictedOtherHeap(block)) + slices.SortStableFunc(block, sortStateResV2ConflictedOtherHeap) for _, s := range block { result = append(result, s.event) } @@ -682,19 +705,18 @@ func kahnsAlgorithmUsingAuthEvents(events []*stateResV2ConflictedPowerLevel) []* // dependencies. These will be placed into the graph first. Remove the event // from the event map as this prevents us from processing it a second time. noIncoming := make(stateResV2ConflictedPowerLevelHeap, 0, len(events)) - heap.Init(&noIncoming) for eventID, count := range inDegree { if count == 0 { - heap.Push(&noIncoming, eventMap[eventID]) + noIncoming.Push(eventMap[eventID]) delete(eventMap, eventID) } } + slices.SortStableFunc(noIncoming, sortStateResV2ConflictedPowerLevelHeap) - var event *stateResV2ConflictedPowerLevel - for noIncoming.Len() > 0 { + for len(noIncoming) > 0 { // Pop the first event ID off the list of events which have no incoming // auth event dependencies. - event = heap.Pop(&noIncoming).(*stateResV2ConflictedPowerLevel) + event := noIncoming.Pop() // Since there are no incoming dependencies to resolve, we can now add this // event into the graph. @@ -715,20 +737,21 @@ func kahnsAlgorithmUsingAuthEvents(events []*stateResV2ConflictedPowerLevel) []* // process the outgoing dependencies of this auth event. if inDegree[auth] == 0 { if _, ok := eventMap[auth]; ok { - heap.Push(&noIncoming, eventMap[auth]) + noIncoming.Push(eventMap[auth]) delete(eventMap, auth) } } } + slices.SortStableFunc(noIncoming, sortStateResV2ConflictedPowerLevelHeap) } // If we have stray events left over then add them into the result. if len(eventMap) > 0 { remaining := make(stateResV2ConflictedPowerLevelHeap, 0, len(events)) for _, event := range eventMap { - heap.Push(&remaining, event) + remaining.Push(event) } - sort.Sort(sort.Reverse(remaining)) + slices.SortStableFunc(remaining, sortStateResV2ConflictedPowerLevelHeap) graph = append(remaining, graph...) } @@ -768,19 +791,18 @@ func kahnsAlgorithmUsingPrevEvents(events []*stateResV2ConflictedOther) []*state // dependencies. These will be placed into the graph first. Remove the event // from the event map as this prevents us from processing it a second time. noIncoming := make(stateResV2ConflictedOtherHeap, 0, len(events)) - heap.Init(&noIncoming) for eventID, count := range inDegree { if count == 0 { - heap.Push(&noIncoming, eventMap[eventID]) + noIncoming.Push(eventMap[eventID]) delete(eventMap, eventID) } } + slices.SortStableFunc(noIncoming, sortStateResV2ConflictedOtherHeap) - var event *stateResV2ConflictedOther - for noIncoming.Len() > 0 { + for len(noIncoming) > 0 { // Pop the first event ID off the list of events which have no incoming // prev event dependencies. - event = heap.Pop(&noIncoming).(*stateResV2ConflictedOther) + event := noIncoming.Pop() // Since there are no incoming dependencies to resolve, we can now add this // event into the graph. @@ -801,20 +823,21 @@ func kahnsAlgorithmUsingPrevEvents(events []*stateResV2ConflictedOther) []*state // process the outgoing dependencies of this prev event. if inDegree[prev] == 0 { if _, ok := eventMap[prev]; ok { - heap.Push(&noIncoming, eventMap[prev]) + noIncoming.Push(eventMap[prev]) delete(eventMap, prev) } } } + slices.SortStableFunc(noIncoming, sortStateResV2ConflictedOtherHeap) } // If we have stray events left over then add them into the result. if len(eventMap) > 0 { remaining := make(stateResV2ConflictedOtherHeap, 0, len(events)) for _, event := range eventMap { - heap.Push(&remaining, event) + remaining = append(remaining, event) } - sort.Sort(sort.Reverse(remaining)) + slices.SortStableFunc(remaining, sortStateResV2ConflictedOtherHeap) graph = append(remaining, graph...) } return graph diff --git a/stateresolutionv2_test.go b/stateresolutionv2_test.go index 418fe6b7..1b97d7dc 100644 --- a/stateresolutionv2_test.go +++ b/stateresolutionv2_test.go @@ -15,7 +15,7 @@ package gomatrixserverlib import ( - "sort" + "slices" "testing" "github.com/matrix-org/gomatrixserverlib/spec" @@ -364,12 +364,13 @@ func TestLexicographicalSorting(t *testing.T) { {eventID: "c", powerLevel: 0, originServerTS: 2}, {eventID: "d", powerLevel: 25, originServerTS: 3}, {eventID: "e", powerLevel: 50, originServerTS: 4}, - {eventID: "f", powerLevel: 75, originServerTS: 4}, - {eventID: "g", powerLevel: 100, originServerTS: 5}, + {eventID: "f", powerLevel: 50, originServerTS: 3}, + {eventID: "g", powerLevel: 75, originServerTS: 4}, + {eventID: "h", powerLevel: 100, originServerTS: 5}, } - expected := []string{"g", "f", "e", "d", "c", "b", "a"} + expected := []string{"h", "g", "f", "e", "d", "a", "b", "c"} - sort.Stable(stateResV2ConflictedPowerLevelHeap(input)) + slices.SortStableFunc(input, sortStateResV2ConflictedPowerLevelHeap) t.Log("Results:") for k, v := range input { diff --git a/stateresolutionv2heaps.go b/stateresolutionv2heaps.go index f8192f1b..fa94747b 100644 --- a/stateresolutionv2heaps.go +++ b/stateresolutionv2heaps.go @@ -37,45 +37,35 @@ type stateResV2ConflictedPowerLevel struct { // ensures that the results are deterministic. type stateResV2ConflictedPowerLevelHeap []*stateResV2ConflictedPowerLevel -// Len implements sort.Interface -func (s stateResV2ConflictedPowerLevelHeap) Len() int { - return len(s) -} - // Less implements sort.Interface -func (s stateResV2ConflictedPowerLevelHeap) Less(i, j int) bool { +func sortStateResV2ConflictedPowerLevelHeap(a, b *stateResV2ConflictedPowerLevel) int { // Try to tiebreak on the effective power level - if s[i].powerLevel > s[j].powerLevel { - return true + if a.powerLevel > b.powerLevel { + return -1 } - if s[i].powerLevel < s[j].powerLevel { - return false + if a.powerLevel < b.powerLevel { + return 1 } // If we've reached here then s[i].powerLevel == s[j].powerLevel // so instead try to tiebreak on origin server TS - if s[i].originServerTS < s[j].originServerTS { - return false + if a.originServerTS < b.originServerTS { + return -1 } - if s[i].originServerTS > s[j].originServerTS { - return true + if a.originServerTS > b.originServerTS { + return 1 } // If we've reached here then s[i].originServerTS == s[j].originServerTS // so instead try to tiebreak on a lexicographical comparison of the event ID - return strings.Compare(s[i].eventID[:], s[j].eventID[:]) > 0 -} - -// Swap implements sort.Interface -func (s stateResV2ConflictedPowerLevelHeap) Swap(i, j int) { - s[i], s[j] = s[j], s[i] + return strings.Compare(a.eventID[:], b.eventID[:]) } // Push implements heap.Interface -func (s *stateResV2ConflictedPowerLevelHeap) Push(x interface{}) { - *s = append(*s, x.(*stateResV2ConflictedPowerLevel)) +func (s *stateResV2ConflictedPowerLevelHeap) Push(x *stateResV2ConflictedPowerLevel) { + *s = append(*s, x) } // Pop implements heap.Interface -func (s *stateResV2ConflictedPowerLevelHeap) Pop() interface{} { +func (s *stateResV2ConflictedPowerLevelHeap) Pop() *stateResV2ConflictedPowerLevel { old := *s n := len(old) x := old[n-1] @@ -101,53 +91,42 @@ type stateResV2ConflictedOther struct { // ensures that the results are deterministic. type stateResV2ConflictedOtherHeap []*stateResV2ConflictedOther -// Len implements sort.Interface -func (s stateResV2ConflictedOtherHeap) Len() int { - return len(s) -} - -// Less implements sort.Interface -func (s stateResV2ConflictedOtherHeap) Less(i, j int) bool { +func sortStateResV2ConflictedOtherHeap(a, b *stateResV2ConflictedOther) int { // Try to tiebreak on the mainline position - if s[i].mainlinePosition < s[j].mainlinePosition { - return true + if a.mainlinePosition < b.mainlinePosition { + return -1 } - if s[i].mainlinePosition > s[j].mainlinePosition { - return false + if a.mainlinePosition > b.mainlinePosition { + return 1 } // If we've reached here then s[i].mainlinePosition == s[j].mainlinePosition // so instead try to tiebreak on step count - if s[i].mainlineSteps < s[j].mainlineSteps { - return true + if a.mainlineSteps < b.mainlineSteps { + return -1 } - if s[i].mainlineSteps > s[j].mainlineSteps { - return false + if a.mainlineSteps > b.mainlineSteps { + return 1 } // If we've reached here then s[i].mainlineSteps == s[j].mainlineSteps // so instead try to tiebreak on origin server TS - if s[i].originServerTS < s[j].originServerTS { - return true + if a.originServerTS < b.originServerTS { + return -1 } - if s[i].originServerTS > s[j].originServerTS { - return false + if a.originServerTS > b.originServerTS { + return 1 } // If we've reached here then s[i].originServerTS == s[j].originServerTS // so instead try to tiebreak on a lexicographical comparison of the event ID - return strings.Compare(s[i].eventID[:], s[j].eventID[:]) < 0 -} - -// Swap implements sort.Interface -func (s stateResV2ConflictedOtherHeap) Swap(i, j int) { - s[i], s[j] = s[j], s[i] + return strings.Compare(a.eventID, b.eventID) } // Push implements heap.Interface -func (s *stateResV2ConflictedOtherHeap) Push(x interface{}) { - *s = append(*s, x.(*stateResV2ConflictedOther)) +func (s *stateResV2ConflictedOtherHeap) Push(x *stateResV2ConflictedOther) { + *s = append(*s, x) } // Pop implements heap.Interface -func (s *stateResV2ConflictedOtherHeap) Pop() interface{} { +func (s *stateResV2ConflictedOtherHeap) Pop() *stateResV2ConflictedOther { old := *s n := len(old) x := old[n-1]