Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context propagation support #783

Merged
merged 3 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 53 additions & 37 deletions internal/net/call/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ const (
idle // can be used for calls, no calls in-flight
active // can be used for calls, some calls in-flight
draining // some calls in-flight, no new calls should be added

hdrLenLen = uint32(4) // size of the header length included in each message
)

var connStateNames = []string{
Expand Down Expand Up @@ -381,9 +383,6 @@ func (rc *reconnectingConnection) Call(ctx context.Context, h MethodKey, arg []b
}

func (rc *reconnectingConnection) callOnce(ctx context.Context, h MethodKey, arg []byte, opts CallOptions) ([]byte, error) {
enc := codegen.NewEncoder()
copy(enc.Grow(len(h)), h[:])

var micros int64
deadline, haveDeadline := ctx.Deadline()
if haveDeadline {
Expand All @@ -398,19 +397,16 @@ func (rc *reconnectingConnection) callOnce(ctx context.Context, h MethodKey, arg
return nil, ctx.Err()
}
}
enc.Int64(micros)

// Send trace information in the header.
writeTraceContext(ctx, enc)

// Send context metadata in the header.
writeContextMetadata(ctx, enc)
// Encode the header.
hdr := encodeHeader(ctx, h, micros)

// Note that we send the header and the payload as follows:
// [header_length][encoded_header][payload]
hdrSlice := make([]byte, 4)
binary.BigEndian.PutUint32(hdrSlice, uint32(len(enc.Data())))
hdrSlice = append(hdrSlice, enc.Data()...)
var hdrLen [hdrLenLen]byte
binary.LittleEndian.PutUint32(hdrLen[:], uint32(len(hdr)))
hdrSlice := hdrLen[:]
hdrSlice = append(hdrSlice, hdr...)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just combine prev 2 lines?
hdrSlice := append(hdrLen[:], hdr...)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


rpc := &call{}
rpc.doneSignal = make(chan struct{})
Expand Down Expand Up @@ -950,40 +946,31 @@ func (c *serverConnection) readRequests(ctx context.Context, hmap *HandlerMap, o
// The result (or error) from the handler is sent back to the client over c.
func (c *serverConnection) runHandler(hmap *HandlerMap, id uint64, msg []byte) {
msgLen := uint32(len(msg))
hdrLenEndOffset := uint32(4)
if msgLen < hdrLenEndOffset {
if msgLen < hdrLenLen {
c.shutdown("server handler", fmt.Errorf("missing request header length"))
return
}

// Get the header length.
hdrLen := binary.BigEndian.Uint32(msg[:hdrLenEndOffset])
hdrEndOffset := hdrLenEndOffset + hdrLen
hdrLen := binary.LittleEndian.Uint32(msg[:hdrLenLen])
hdrEndOffset := hdrLenLen + hdrLen
if msgLen < hdrEndOffset {
c.shutdown("server handler", fmt.Errorf("missing request header"))
return
}

// Extract the encoded request header using a decoder.
dec := codegen.NewDecoder(msg[hdrLenEndOffset:hdrEndOffset])
// Extracts header information.
ctx, hkey, micros, sc := decodeHeader(msg[hdrLenLen:hdrEndOffset])

// Extract handler key.
var hkey MethodKey
copy(hkey[:], dec.Read(len(hkey)))

// Extract the method name
// Extracts the method name.
methodName := hmap.names[hkey]
if methodName == "" {
methodName = "handler"
} else {
methodName = logging.ShortenComponent(methodName)
}

// Add deadline information from the header to the context.
ctx := context.Background()
var cancelFunc func()

micros := dec.Int64()
if micros != 0 {
deadline := time.Now().Add(time.Microsecond * time.Duration(micros))
ctx, cancelFunc = context.WithDeadline(ctx, deadline)
Expand All @@ -996,22 +983,17 @@ func (c *serverConnection) runHandler(hmap *HandlerMap, id uint64, msg []byte) {
}
}()

// Extract trace context and create a new child span to trace the method
// call on the server.
// Create a new child span to trace the method call on the server.
span := trace.SpanFromContext(ctx) // noop span
if sc := readTraceContext(dec); sc != nil {
if sc.IsValid() {
ctx, span = c.opts.Tracer.Start(trace.ContextWithSpanContext(ctx, *sc), methodName, trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
} else {
if sc != nil {
if !sc.IsValid() {
c.shutdown("server handler", fmt.Errorf("invalid span context"))
return
}
ctx, span = c.opts.Tracer.Start(trace.ContextWithSpanContext(ctx, *sc), methodName, trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
}

// Extract metadata context information if any.
ctx = readContextMetadata(ctx, dec)

// Call the handler passing it the payload.
payload := msg[hdrEndOffset:]
var err error
Expand Down Expand Up @@ -1079,6 +1061,40 @@ func (c *serverConnection) shutdown(details string, err error) {
}
}

// encodeHeader encodes the header information that is propagated by each message.
func encodeHeader(ctx context.Context, h MethodKey, micros int64) []byte {
enc := codegen.NewEncoder()
copy(enc.Grow(len(h)), h[:])
enc.Int64(micros)

// Send trace information in the header.
writeTraceContext(ctx, enc)

// Send context metadata in the header.
writeContextMetadata(ctx, enc)

return enc.Data()
}

// decodeHeader extracts the encoded header information.
func decodeHeader(hdr []byte) (context.Context, MethodKey, int64, *trace.SpanContext) {
dec := codegen.NewDecoder(hdr)

// Extract handler key.
var hkey MethodKey
copy(hkey[:], dec.Read(len(hkey)))

// Extract deadline information.
micros := dec.Int64()

// Extract trace context information.
sc := readTraceContext(dec)

// Extract metadata context information if any.
ctx := readContextMetadata(context.Background(), dec)
return ctx, hkey, micros, sc
}

func logError(logger *slog.Logger, details string, err error) {
if errors.Is(err, context.Canceled) ||
errors.Is(err, io.EOF) ||
Expand Down
20 changes: 13 additions & 7 deletions internal/net/call/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,19 @@ const currentVersion = initialVersion
// version [4]byte
//
// requestMessage:
// headerLen [4]byte -- length of the encoded header
// header [length]byte -- encoded header information
// headerKey [16]byte -- fingerprint of method name
// deadline [8]byte -- zero, or deadline in microseconds
// traceContext [25]byte -- zero, or trace context
// metadataContext [length]byte -- encoded map[string]string
// remainder -- call argument serialization
// headerLen [4]byte -- length of the encoded header
// header [headerLen]byte -- encoded header information
// payload -- call argument serialization
//
// The header is encoded using Service Weaver's encoding format for a type that
// looks like:
//
// struct header {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line up the types just like "go fmt" would.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// MethodKey [16]byte
// Deadline int64
// TraceContext [25]byte
// MetadataContext map[string]string
// }
//
// responseMessage:
// payload holds call result serialization
Expand Down
3 changes: 1 addition & 2 deletions weavertest/internal/simple/simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ func (d *destination) GetAll(_ context.Context, file string) ([]string, error) {
func (d *destination) UpdateMetadata(ctx context.Context) error {
d.mu.Lock()
defer d.mu.Unlock()
meta, found := metadata.FromContext(ctx)
if found {
if meta, found := metadata.FromContext(ctx); found {
d.metadata = maps.Clone(meta)
}
return nil
Expand Down
Loading