Skip to content

Commit

Permalink
Performance and bug-fix backports from Harmony
Browse files Browse the repository at this point in the history
This contains various backports:
* Reduced allocations and performance improvements in state resolution
* Reduced allocations and performance improvements in JSON handling
* Event auth fixes, including correct error surfacing
  • Loading branch information
neilalexander committed Dec 13, 2024
1 parent dbd5f31 commit b3fced3
Show file tree
Hide file tree
Showing 15 changed files with 180 additions and 176 deletions.
14 changes: 10 additions & 4 deletions authstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ func checkAllowedByAuthEvents(
event PDU, eventsByID map[string]PDU,
missingAuth EventProvider, userIDForSender spec.UserIDForSender,
) error {
authEvents := NewAuthEvents(nil)
authEvents, err := NewAuthEvents(nil)
if err != nil {
return err
}

for _, ae := range event.AuthEventIDs() {
retryEvent:
Expand Down Expand Up @@ -214,7 +217,7 @@ func checkAllowedByAuthEvents(

// If we made it this far then we've successfully got as many of the auth events as
// as described by AuthEventIDs(). Check if they allow the event.
if err := Allowed(event, &authEvents, userIDForSender); err != nil {
if err := Allowed(event, authEvents, userIDForSender); err != nil {
return fmt.Errorf(
"gomatrixserverlib: event with ID %q is not allowed by its auth_events: %s",
event.EventID(), err.Error(),
Expand Down Expand Up @@ -335,7 +338,10 @@ func CheckSendJoinResponse(
}

eventsByID := map[string]PDU{}
authEventProvider := NewAuthEvents(nil)
authEventProvider, err := NewAuthEvents(nil)
if err != nil {
return nil, err
}

// Since checkAllowedByAuthEvents needs to be able to look up any of the
// auth events by ID only, we will build a map which contains references
Expand Down Expand Up @@ -369,7 +375,7 @@ func CheckSendJoinResponse(
}

// Now check that the join event is valid against the supplied state.
if err := Allowed(joinEvent, &authEventProvider, userIDForSender); err != nil {
if err := Allowed(joinEvent, authEventProvider, userIDForSender); err != nil {
return nil, fmt.Errorf(
"gomatrixserverlib: event with ID %q is not allowed by the current room state: %w",
joinEvent.EventID(), err,
Expand Down
3 changes: 2 additions & 1 deletion backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ func RequestBackfill(ctx context.Context, origin spec.ServerName, b BackfillRequ
}
}

return result, lastErr
// Since we pulled in results from multiple servers we need to sort again...
return ReverseTopologicalOrdering(result, TopologicalOrderByPrevEvents), lastErr
}

/*
Expand Down
3 changes: 1 addition & 2 deletions eventV1.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ func (e *eventV1) SetUnsignedField(path string, value interface{}) error {
eventJSON = CanonicalJSONAssumeValid(eventJSON)

res := gjson.GetBytes(eventJSON, "unsigned")
unsigned := RawJSONFromResult(res, eventJSON)
e.eventFields.Unsigned = unsigned
e.eventFields.Unsigned = []byte(res.Raw)

e.eventJSON = eventJSON

Expand Down
8 changes: 5 additions & 3 deletions eventauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,17 @@ func (a *AuthEvents) Clear() {

// NewAuthEvents returns an AuthEventProvider backed by the given events. New events can be added by
// calling AddEvent().
func NewAuthEvents(events []PDU) AuthEvents {
func NewAuthEvents(events []PDU) (*AuthEvents, error) {
a := AuthEvents{
events: make(map[StateKeyTuple]PDU, len(events)),
roomIDs: make(map[string]struct{}),
}
for _, e := range events {
a.AddEvent(e) // nolint: errcheck
if err := a.AddEvent(e); err != nil {
return nil, err
}
}
return a
return &a, nil
}

// A NotAllowed error is returned if an event does not pass the auth checks.
Expand Down
14 changes: 7 additions & 7 deletions eventauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestStateNeededForMessage(t *testing.T) {
// Message events need the create event, the sender and the power_levels.
testStateNeededForAuth(t, `[{
"type": "m.room.message",
"sender": "@u1:a",
"sender": "@u1:a",
"room_id": "!r1:a"
}]`, &ProtoEvent{
Type: "m.room.message",
Expand Down Expand Up @@ -139,7 +139,7 @@ func TestStateNeededForJoin(t *testing.T) {
"type": "m.room.member",
"state_key": "@u1:a",
"sender": "@u1:a",
"content": {"membership": "join"},
"content": {"membership": "join"},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
Expand All @@ -163,7 +163,7 @@ func TestStateNeededForInvite(t *testing.T) {
"type": "m.room.member",
"state_key": "@u2:b",
"sender": "@u1:a",
"content": {"membership": "invite"},
"content": {"membership": "invite"},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestStateNeededForInvite3PID(t *testing.T) {
"token": "my_token"
}
}
},
},
"room_id": "!r1:a"
}]`, &b, StateNeeded{
Create: true,
Expand Down Expand Up @@ -1035,7 +1035,7 @@ func TestAuthEvents(t *testing.T) {
if err != nil {
t.Fatalf("TestAuthEvents: failed to create power_levels event: %s", err)
}
a := NewAuthEvents([]PDU{power})
a, _ := NewAuthEvents([]PDU{power})
var e PDU
if e, err = a.PowerLevels(); err != nil || e != power {
t.Errorf("TestAuthEvents: failed to get same power_levels event")
Expand Down Expand Up @@ -1685,15 +1685,15 @@ func TestMembershipBanned(t *testing.T) {
"state_key": "@u3:a",
"event_id": "$e4:a",
"content": {"membership": "ban"}
},
},
{
"type": "m.room.member",
"sender": "@u2:a",
"room_id": "!r1:a",
"state_key": "@u3:a",
"event_id": "$e4:a",
"content": {"membership": "ban"}
},
},
{
"type": "m.room.member",
"sender": "@u2:a",
Expand Down
8 changes: 5 additions & 3 deletions handleinvite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ func (r *TestStateQuerier) GetAuthEvents(ctx context.Context, event PDU) (AuthEv
return nil, fmt.Errorf("failed getting auth provider")
}

eventProvider := AuthEvents{}
eventProvider, _ := NewAuthEvents(nil)
if r.createEvent != nil {
eventProvider = NewAuthEvents([]PDU{r.createEvent})
if err := eventProvider.AddEvent(r.createEvent); err != nil {
return nil, err
}
if r.inviterMemberEvent != nil {
err := eventProvider.AddEvent(r.inviterMemberEvent)
if err != nil {
return nil, err
}
}
}
return &eventProvider, nil
return eventProvider, nil
}

func (r *TestStateQuerier) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []StateKeyTuple) ([]PDU, error) {
Expand Down
7 changes: 5 additions & 2 deletions handlejoin.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,11 @@ func HandleMakeJoin(input HandleMakeJoinInput) (*HandleMakeJoinResponse, error)
return nil, spec.InternalServerError{Err: fmt.Sprintf("expected join event from template builder. got: %s", event.Type())}
}

provider := NewAuthEvents(state)
if err = Allowed(event, &provider, input.UserIDQuerier); err != nil {
provider, err := NewAuthEvents(state)
if err != nil {
return nil, spec.Forbidden(err.Error())
}
if err = Allowed(event, provider, input.UserIDQuerier); err != nil {
return nil, spec.Forbidden(err.Error())
}

Expand Down
1 change: 1 addition & 0 deletions handlejoin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ func TestHandleMakeJoinNilContext(t *testing.T) {
})
}

//nolint:unparam
func createMemberEventBuilder(roomVersion RoomVersion, sender string, roomID string, stateKey *string, content spec.RawJSON) *EventBuilder {
return MustGetRoomVersion(roomVersion).NewEventBuilderFromProtoEvent(&ProtoEvent{
SenderID: sender,
Expand Down
7 changes: 5 additions & 2 deletions handleleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro
return nil, spec.InternalServerError{Err: fmt.Sprintf("expected leave event from template builder. got: %s", event.Type())}
}

provider := NewAuthEvents(stateEvents)
if err := Allowed(event, &provider, input.UserIDQuerier); err != nil {
provider, err := NewAuthEvents(stateEvents)
if err != nil {
return nil, spec.Forbidden(err.Error())
}
if err = Allowed(event, provider, input.UserIDQuerier); err != nil {
return nil, spec.Forbidden(err.Error())
}

Expand Down
63 changes: 24 additions & 39 deletions json.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package gomatrixserverlib
import (
"encoding/binary"
"errors"
"sort"
"slices"

Check failure on line 21 in json.go

View workflow job for this annotation

GitHub Actions / Unit tests (Go 1.20)

package slices is not in GOROOT (/opt/hostedtoolcache/go/1.20.14/x64/src/slices)
"strings"
"unicode/utf16"
"unicode/utf8"
Expand Down Expand Up @@ -159,40 +159,33 @@ func CanonicalJSONAssumeValid(input []byte) []byte {
// by codepoint. The input must be valid JSON.
func SortJSON(input, output []byte) []byte {
result := gjson.ParseBytes(input)

RawJSON := RawJSONFromResult(result, input)
return sortJSONValue(result, RawJSON, output)
return sortJSONValue(result, output)
}

// sortJSONValue takes a gjson.Result and sorts it. inputJSON must be the
// raw JSON bytes that gjson.Result points to.
func sortJSONValue(input gjson.Result, inputJSON, output []byte) []byte {
func sortJSONValue(input gjson.Result, output []byte) []byte {
if input.IsArray() {
return sortJSONArray(input, inputJSON, output)
return sortJSONArray(input, output)
}

if input.IsObject() {
return sortJSONObject(input, inputJSON, output)
return sortJSONObject(input, output)
}

// If its neither an object nor an array then there is no sub structure
// to sort, so just append the raw bytes.
return append(output, inputJSON...)
return append(output, input.Raw...)
}

// sortJSONArray takes a gjson.Result and sorts it, assuming its an array.
// inputJSON must be the raw JSON bytes that gjson.Result points to.
func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte {
func sortJSONArray(input gjson.Result, output []byte) []byte {
sep := byte('[')

// Iterate over each value in the array and sort it.
input.ForEach(func(_, value gjson.Result) bool {
output = append(output, sep)
sep = ','

RawJSON := RawJSONFromResult(value, inputJSON)
output = sortJSONValue(value, RawJSON, output)

output = sortJSONValue(value, output)
return true // keep iterating
})

Expand All @@ -209,29 +202,30 @@ func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte {

// sortJSONObject takes a gjson.Result and sorts it, assuming its an object.
// inputJSON must be the raw JSON bytes that gjson.Result points to.
func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
func sortJSONObject(input gjson.Result, output []byte) []byte {
type entry struct {
key string // The parsed key string
rawKey []byte // The raw, unparsed key JSON string
value gjson.Result
key string // The parsed key string
value gjson.Result
}

var entries []entry
// Try to stay on the stack here if we can.
var _entries [128]entry
entries := _entries[:0]

// Iterate over each key/value pair and add it to a slice
// that we can sort
input.ForEach(func(key, value gjson.Result) bool {
entries = append(entries, entry{
key: key.String(),
rawKey: RawJSONFromResult(key, inputJSON),
value: value,
key: key.String(),
value: value,
})
return true // keep iterating
})

// Sort the slice based on the *parsed* key
sort.Slice(entries, func(a, b int) bool {
return entries[a].key < entries[b].key
// Using slices.SortFunc here instead of sort.Slice avoids
// heap escapes due to reflection.
slices.SortFunc(entries, func(a, b entry) int {
return strings.Compare(a.key, b.key)
})

sep := byte('{')
Expand All @@ -241,12 +235,10 @@ func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
sep = ','

// Append the raw unparsed JSON key, *not* the parsed key
output = append(output, entry.rawKey...)
output = append(output, ':')

RawJSON := RawJSONFromResult(entry.value, inputJSON)

output = sortJSONValue(entry.value, RawJSON, output)
output = append(output, '"')
output = append(output, entry.key...)
output = append(output, '"', ':')
output = sortJSONValue(entry.value, output)
}
if sep == '{' {
// If sep is still '{' then the object was empty and we never wrote the
Expand Down Expand Up @@ -375,10 +367,3 @@ func readHexDigits(input []byte) rune {
hex |= hex >> 8
return rune(hex & 0xFFFF)
}

// RawJSONFromResult extracts the raw JSON bytes pointed to by result.
// input must be the json bytes that were used to generate result
// TODO: Why do we do this?
func RawJSONFromResult(result gjson.Result, _ []byte) []byte {
return []byte(result.Raw)
}
4 changes: 2 additions & 2 deletions performinvite.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede

input.EventTemplate.Depth = latestEvents.Depth

authEvents := NewAuthEvents(nil)
authEvents, _ := NewAuthEvents(nil)

for _, event := range latestEvents.StateEvents {
err := authEvents.AddEvent(event)
Expand All @@ -132,7 +132,7 @@ func PerformInvite(ctx context.Context, input PerformInviteInput, fedClient Fede
}
}

refs, err := stateNeeded.AuthEventReferences(&authEvents)
refs, err := stateNeeded.AuthEventReferences(authEvents)
if err != nil {
return nil, fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err)
}
Expand Down
3 changes: 1 addition & 2 deletions performjoin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
ed255192 "golang.org/x/crypto/ed25519"
)

type TestMakeJoinResponse struct {
Expand Down Expand Up @@ -388,7 +387,7 @@ func TestPerformJoinPseudoID(t *testing.T) {
return res, nil
}

idCreator := func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed255192.PrivateKey, error) {
idCreator := func(ctx context.Context, userID spec.UserID, roomID spec.RoomID, roomVersion string) (spec.SenderID, ed25519.PrivateKey, error) {
return spec.SenderIDFromPseudoIDKey(userPriv), userPriv, nil
}

Expand Down
Loading

0 comments on commit b3fced3

Please sign in to comment.