Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: move concat & concat message extra by ConcatItems #66

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .testcoverage.yml
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ local-prefix: "github.com/cloudwego/eino"
threshold:
# (optional; default 0)
# Minimum overall project coverage percentage required.
total: 75
total: 83

package: 30

270 changes: 3 additions & 267 deletions compose/stream_concat.go
Original file line number Diff line number Diff line change
@@ -19,77 +19,11 @@ package compose
import (
"fmt"
"io"
"reflect"
"strings"

"github.com/cloudwego/eino/internal"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino/utils/generic"
)

var (
concatFuncs = map[reflect.Type]any{
generic.TypeOf[*schema.Message](): schema.ConcatMessages,
generic.TypeOf[string](): concatStrings,
generic.TypeOf[[]*schema.Message](): concatMessageArray,
}
)

func concatStrings(ss []string) (string, error) {
var n int
for _, s := range ss {
n += len(s)
}

var b strings.Builder
b.Grow(n)
for _, s := range ss {
_, err := b.WriteString(s)
if err != nil {
return "", err
}
}

return b.String(), nil
}

func concatMessageArray(mas [][]*schema.Message) ([]*schema.Message, error) {
arrayLen := len(mas[0])

ret := make([]*schema.Message, arrayLen)
slicesToConcat := make([][]*schema.Message, arrayLen)

for _, ma := range mas {
if len(ma) != arrayLen {
return nil, fmt.Errorf("unexpected array length. "+
"Got %d, expected %d", len(ma), arrayLen)
}

for i := 0; i < arrayLen; i++ {
m := ma[i]
if m != nil {
slicesToConcat[i] = append(slicesToConcat[i], m)
}
}
}

for i, slice := range slicesToConcat {
if len(slice) == 0 {
ret[i] = nil
} else if len(slice) == 1 {
ret[i] = slice[0]
} else {
cm, err := schema.ConcatMessages(slice)
if err != nil {
return nil, err
}

ret[i] = cm
}
}

return ret, nil
}

// RegisterStreamChunkConcatFunc registers a function to concat stream chunks.
// It's required when you want to concat stream chunks of a specific type.
// for example you call Invoke() but node only implements Stream().
@@ -109,41 +43,7 @@ func concatMessageArray(mas [][]*schema.Message) ([]*schema.Message, error) {
// }, nil
// })
func RegisterStreamChunkConcatFunc[T any](fn func([]T) (T, error)) {
concatFuncs[generic.TypeOf[T]()] = fn
}

func getConcatFunc(tpe reflect.Type) func(reflect.Value) (reflect.Value, error) {
if fn, ok := concatFuncs[tpe]; ok {
return func(a reflect.Value) (reflect.Value, error) {
rvs := reflect.ValueOf(fn).Call([]reflect.Value{a})
var err error
if !rvs[1].IsNil() {
err = rvs[1].Interface().(error)
}
return rvs[0], err
}
}

return nil
}

func toSliceValue(vs []any) (reflect.Value, error) {
typ := reflect.TypeOf(vs[0])

ret := reflect.MakeSlice(reflect.SliceOf(typ), len(vs), len(vs))
ret.Index(0).Set(reflect.ValueOf(vs[0]))

for i := 1; i < len(vs); i++ {
v := vs[i]
vt := reflect.TypeOf(v)
if typ != vt {
return reflect.Value{}, fmt.Errorf("unexpected slice element type. Got %v, expected %v", typ, vt)
}

ret.Index(i).Set(reflect.ValueOf(v))
}

return ret, nil
internal.RegisterStreamChunkConcatFunc(fn)
}

func concatStreamReader[T any](sr *schema.StreamReader[T]) (T, error) {
@@ -174,174 +74,10 @@ func concatStreamReader[T any](sr *schema.StreamReader[T]) (T, error) {
return items[0], nil
}

res, err := concatItems(items)
res, err := internal.ConcatItems(items)
if err != nil {
var t T
return t, err
}
return res, nil
}

// the caller should ensure len(items) > 1
func concatItems[T any](items []T) (T, error) {
typ := generic.TypeOf[T]()
v := reflect.ValueOf(items)

var cv reflect.Value
var err error

// handle map kind
if typ.Kind() == reflect.Map {
cv, err = concatMaps(v)
} else {
cv, err = concatSliceValue(v)
}

if err != nil {
var t T
return t, err
}

return cv.Interface().(T), nil
}

func concatMaps(ms reflect.Value) (reflect.Value, error) {
typ := ms.Type().Elem()

rms := reflect.MakeMap(reflect.MapOf(typ.Key(), generic.TypeOf[[]any]()))
ret := reflect.MakeMap(typ)

n := ms.Len()
for i := 0; i < n; i++ {
m := ms.Index(i)

for _, key := range m.MapKeys() {
vals := rms.MapIndex(key)
if !vals.IsValid() {
var s []any
vals = reflect.ValueOf(s)
}

val := m.MapIndex(key)
vals = reflect.Append(vals, val)
rms.SetMapIndex(key, vals)
}
}

for _, key := range rms.MapKeys() {
vals := rms.MapIndex(key)

anyVals := vals.Interface().([]any)
v, err := toSliceValue(anyVals)
if err != nil {
return reflect.Value{}, err
}

var cv reflect.Value

if v.Type().Elem().Kind() == reflect.Map {
cv, err = concatMaps(v)
} else {
cv, err = concatSliceValue(v)
}

if err != nil {
return reflect.Value{}, err
}

ret.SetMapIndex(key, cv)
}

return ret, nil
}

func concatSliceValue(val reflect.Value) (reflect.Value, error) {
elmType := val.Type().Elem()

if val.Len() == 1 {
return val.Index(0), nil
}

f := getConcatFunc(elmType)
if f != nil {
return f(val)
}

var (
structType reflect.Type
isStructPtr bool
)

if elmType.Kind() == reflect.Struct {
structType = elmType
} else if elmType.Kind() == reflect.Pointer && elmType.Elem().Kind() == reflect.Struct {
isStructPtr = true
structType = elmType.Elem()
}

if structType != nil {
maps := make([]map[string]any, 0, val.Len())
for i := 0; i < val.Len(); i++ {
sliceElem := val.Index(i)
m, err := structToMap(sliceElem)
if err != nil {
return reflect.Value{}, err
}

maps = append(maps, m)
}

result, err := concatMaps(reflect.ValueOf(maps))
if err != nil {
return reflect.Value{}, err
}

return mapToStruct(result.Interface().(map[string]any), structType, isStructPtr), nil
}

var filtered reflect.Value
for i := 0; i < val.Len(); i++ {
oneVal := val.Index(i)
if !oneVal.IsZero() {
if filtered.IsValid() {
return reflect.Value{}, fmt.Errorf("cannot concat multiple non-zero value of type %s", elmType)
}

filtered = oneVal
}
}

return filtered, nil
}

func structToMap(s reflect.Value) (map[string]any, error) {
if s.Kind() == reflect.Ptr {
s = s.Elem()
}

ret := make(map[string]any, s.NumField())
for i := 0; i < s.NumField(); i++ {
fieldType := s.Type().Field(i)
if !fieldType.IsExported() {
return nil, fmt.Errorf("structToMap: field %s is not exported", fieldType.Name)
}

ret[fieldType.Name] = s.Field(i).Interface()
}

return ret, nil
}

func mapToStruct(m map[string]any, t reflect.Type, toPtr bool) reflect.Value {
ret := reflect.New(t).Elem()
for k, v := range m {
field := ret.FieldByName(k)
field.Set(reflect.ValueOf(v))
}

if toPtr {
ret = ret.Addr()
}

return ret
}
23 changes: 20 additions & 3 deletions compose/stream_concat_test.go
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ import (

"github.com/stretchr/testify/assert"

"github.com/cloudwego/eino/internal"
"github.com/cloudwego/eino/schema"
)

@@ -112,7 +113,7 @@ func TestMessageConcat(t *testing.T) {
assert.Equal(t, "0123456789", lastVal.Content)
assert.Len(t, lastVal.Extra, 4)
assert.Equal(t, map[string]any{
"key_1": "8",
"key_1": "048",
"0": "0",
"4": "4",
"8": "8",
@@ -193,14 +194,30 @@ func TestConcatError(t *testing.T) {
"str": "string_02",
"x": 123,
}
_, err := concatItems([]map[string]any{a, b})
_, err := internal.ConcatItems([]map[string]any{a, b})
assert.NotNil(t, err)
})

t.Run("merge error", func(t *testing.T) {
RegisterStreamChunkConcatFunc(concatTStreamError)

_, err := concatItems([]tConcatErrForTest{{}, {}})
_, err := internal.ConcatItems([]tConcatErrForTest{{}, {}})
assert.NotNil(t, err)
})
}

func TestConcatSliceValue(t *testing.T) {
type testStruct struct {
A string
}

s := []testStruct{{}, {A: "123"}, {}}
result, err := internal.ConcatItems(s)
assert.Nil(t, err)
assert.Equal(t, testStruct{A: "123"}, result)

s = []testStruct{{}, {}, {}}
result, err = internal.ConcatItems(s)
assert.Nil(t, err)
assert.Equal(t, testStruct{}, result)
}
3 changes: 2 additions & 1 deletion compose/tool_node_test.go
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@ import (
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/components/tool/utils"
"github.com/cloudwego/eino/internal"
"github.com/cloudwego/eino/schema"
)

@@ -410,7 +411,7 @@ func TestToolsNodeOptions(t *testing.T) {
}
outStream.Close()

msgs, err := concatMessageArray(outMessages)
msgs, err := internal.ConcatItems(outMessages)
assert.NoError(t, err)

assert.Len(t, msgs, 1)
11 changes: 11 additions & 0 deletions compose/workflow_test.go
Original file line number Diff line number Diff line change
@@ -63,6 +63,17 @@ func TestWorkflow(t *testing.T) {
B int
StateTemp string
}
RegisterStreamChunkConcatFunc(func(ts []*structF) (*structF, error) {
ret := &structF{}
for _, tt := range ts {
ret.Field1 += tt.Field1
ret.Field2 += tt.Field2
ret.Field3 = append(ret.Field3, tt.Field3...)
ret.B += tt.B
ret.StateTemp += tt.StateTemp
}
return ret, nil
})

type state struct {
temp string
Loading