diff --git a/plugin/storage/memory/memory.go b/plugin/storage/memory/memory.go index ee5e1643cee..8e6516f5853 100644 --- a/plugin/storage/memory/memory.go +++ b/plugin/storage/memory/memory.go @@ -170,11 +170,37 @@ func (m *Store) GetTrace(ctx context.Context, traceID model.TraceID) (*model.Tra // Spans may still be added to traces after they are returned to user code, so make copies. func (m *Store) copyTrace(trace *model.Trace) *model.Trace { return &model.Trace{ - Spans: append([]*model.Span(nil), trace.Spans...), + Spans: m.copySpans(trace), Warnings: append([]string(nil), trace.Warnings...), } } +func (m *Store) copySpans(trace *model.Trace) []*model.Span { + spans := make([]*model.Span, len(trace.Spans)) + for i := range trace.Spans { + spans[i] = m.copySpan(trace.Spans[i]) + } + return spans +} + +// Copy static span attributes and discard the ones that can be altered outside of the store. +func (m *Store) copySpan(span *model.Span) *model.Span { + return &model.Span{ + TraceID: span.TraceID, + SpanID: span.SpanID, + OperationName: span.OperationName, + References: span.References, + Flags: span.Flags, + StartTime: span.StartTime, + Duration: span.Duration, + Tags: span.Tags, + Logs: span.Logs, + Process: span.Process, + ProcessID: span.ProcessID, + Warnings: []string(nil), + } +} + // GetServices returns a list of all known services func (m *Store) GetServices(ctx context.Context) ([]string, error) { m.RLock() diff --git a/plugin/storage/memory/memory_test.go b/plugin/storage/memory/memory_test.go index 104e41a41c4..e296bd90012 100644 --- a/plugin/storage/memory/memory_test.go +++ b/plugin/storage/memory/memory_test.go @@ -210,6 +210,24 @@ func TestStoreGetTraceSuccess(t *testing.T) { }) } +func TestStoreGetAndMutateTrace(t *testing.T) { + withPopulatedMemoryStore(func(store *Store) { + trace, err := store.GetTrace(context.Background(), testingSpan.TraceID) + assert.NoError(t, err) + assert.Len(t, trace.Spans, 1) + assert.Equal(t, testingSpan, trace.Spans[0]) + assert.Len(t, trace.Spans[0].Warnings, 0) + + trace.Spans[0].Warnings = append(trace.Spans[0].Warnings, "the end is near") + + trace, err = store.GetTrace(context.Background(), testingSpan.TraceID) + assert.NoError(t, err) + assert.Len(t, trace.Spans, 1) + assert.Equal(t, testingSpan, trace.Spans[0]) + assert.Len(t, trace.Spans[0].Warnings, 0) + }) +} + func TestStoreGetTraceFailure(t *testing.T) { withPopulatedMemoryStore(func(store *Store) { trace, err := store.GetTrace(context.Background(), model.TraceID{})