Skip to content

Commit

Permalink
sketch of go1.23 push iterators
Browse files Browse the repository at this point in the history
This change defines two helpers, Range and MapRange, that
adapt the Iterable and IterableMapping interfaces to
one- and two-variable range loops, and updates a variety
of calls to use them.

Pros:
- more concise looping syntax.
- no need to worry about Done.
Cons:
- possibly less efficient? Need to measure.
- no "i int" index variable for free.

Naming suggestions welcome.

Obviously we can't submit this for a couple of years,
though we could add the new API behind a go1.23 build tag.
  • Loading branch information
adonovan committed Feb 19, 2024
1 parent 2232540 commit c3c9960
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 102 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module go.starlark.net

go 1.18
go 1.23

require (
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e
Expand Down
7 changes: 3 additions & 4 deletions lib/json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,15 @@ func encode(thread *starlark.Thread, b *starlark.Builtin, args starlark.Tuple, k
case starlark.Iterable:
// e.g. tuple, list
buf.WriteByte('[')
iter := x.Iterate()
defer iter.Done()
var elem starlark.Value
for i := 0; iter.Next(&elem); i++ {
i := 0
for elem := range starlark.Range(x) {
if i > 0 {
buf.WriteByte(',')
}
if err := emit(elem); err != nil {
return fmt.Errorf("at %s index %d: %v", x.Type(), i, err)
}
i++
}
buf.WriteByte(']')

Expand Down
88 changes: 34 additions & 54 deletions lib/proto/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
//
// This package defines several types of Starlark value:
//
// Message -- a protocol message
// RepeatedField -- a repeated field of a message, like a list
// Message -- a protocol message
// RepeatedField -- a repeated field of a message, like a list
//
// FileDescriptor -- information about a .proto file
// FieldDescriptor -- information about a message field (or extension field)
// MessageDescriptor -- information about the type of a message
// EnumDescriptor -- information about an enumerated type
// EnumValueDescriptor -- a value of an enumerated type
// FileDescriptor -- information about a .proto file
// FieldDescriptor -- information about a message field (or extension field)
// MessageDescriptor -- information about the type of a message
// EnumDescriptor -- information about an enumerated type
// EnumValueDescriptor -- a value of an enumerated type
//
// A Message value is a wrapper around a protocol message instance.
// Starlark programs may access and update Messages using dot notation:
//
// x = msg.field
// msg.field = x + 1
// msg.field += 1
// x = msg.field
// msg.field = x + 1
// msg.field += 1
//
// Assignments to message fields perform dynamic checks on the type and
// range of the value to ensure that the message is at all times valid.
Expand All @@ -35,31 +35,30 @@
// performs a dynamic check to ensure that the RepeatedField holds
// only elements of the correct type.
//
// type(msg.uint32s) # "proto.repeated<uint32>"
// msg.uint32s[0] = 1
// msg.uint32s[0] = -1 # error: invalid uint32: -1
// type(msg.uint32s) # "proto.repeated<uint32>"
// msg.uint32s[0] = 1
// msg.uint32s[0] = -1 # error: invalid uint32: -1
//
// Any iterable may be assigned to a repeated field of a message. If
// the iterable is itself a value of type RepeatedField, the message
// field holds a reference to it.
//
// msg2.uint32s = msg.uint32s # both messages share one RepeatedField
// msg.uint32s[0] = 123
// print(msg2.uint32s[0]) # "123"
// msg2.uint32s = msg.uint32s # both messages share one RepeatedField
// msg.uint32s[0] = 123
// print(msg2.uint32s[0]) # "123"
//
// The RepeatedFields' element types must match.
// It is not enough for the values to be merely valid:
//
// msg.uint32s = [1, 2, 3] # makes a copy
// msg.uint64s = msg.uint32s # error: repeated field has wrong type
// msg.uint64s = list(msg.uint32s) # ok; makes a copy
// msg.uint32s = [1, 2, 3] # makes a copy
// msg.uint64s = msg.uint32s # error: repeated field has wrong type
// msg.uint64s = list(msg.uint32s) # ok; makes a copy
//
// For all other iterables, a new RepeatedField is constructed from the
// elements of the iterable.
//
// msg.uints32s = [1, 2, 3]
// print(type(msg.uints32s)) # "proto.repeated<uint32>"
//
// msg.uints32s = [1, 2, 3]
// print(type(msg.uints32s)) # "proto.repeated<uint32>"
//
// To construct a Message from encoded binary or text data, call
// Unmarshal or UnmarshalText. These two functions are exposed to
Expand All @@ -75,7 +74,6 @@
//
// See proto_test.go for an example of how to use the 'proto'
// module in an application that embeds Starlark.
//
package proto

// TODO(adonovan): Go and Starlark API improvements:
Expand Down Expand Up @@ -111,8 +109,8 @@ import (
// for a Starlark thread to use this package.
//
// For example:
// SetPool(thread, protoregistry.GlobalFiles)
//
// SetPool(thread, protoregistry.GlobalFiles)
func SetPool(thread *starlark.Thread, pool DescriptorPool) {
thread.SetLocal(contextKey, pool)
}
Expand Down Expand Up @@ -305,10 +303,9 @@ func getFieldStarlark(thread *starlark.Thread, fn *starlark.Builtin, args starla
// When a message descriptor is called, it returns a new instance of the
// protocol message it describes.
//
// Message(msg) -- return a shallow copy of an existing message
// Message(k=v, ...) -- return a new message with the specified fields
// Message(dict(...)) -- return a new message with the specified fields
//
// Message(msg) -- return a shallow copy of an existing message
// Message(k=v, ...) -- return a new message with the specified fields
// Message(dict(...)) -- return a new message with the specified fields
func (d MessageDescriptor) CallInternal(thread *starlark.Thread, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
dest := &Message{
msg: newMessage(d.Desc),
Expand Down Expand Up @@ -389,19 +386,17 @@ func setField(msg protoreflect.Message, fdesc protoreflect.FieldDescriptor, valu
// x = []; msg.x = x; y = msg.x
// causes x and y not to alias.
if fdesc.IsList() {
iter := starlark.Iterate(value)
if iter == nil {
iterable, ok := value.(starlark.Iterable)
if !ok {
return fmt.Errorf("got %s for .%s field, want iterable", value.Type(), fdesc.Name())
}
defer iter.Done()

list := msg.Mutable(fdesc).List()
list.Truncate(0)
var x starlark.Value
for i := 0; iter.Next(&x); i++ {
for x := range starlark.Range(iterable) {
v, err := toProto(fdesc, x)
if err != nil {
return fmt.Errorf("index %d: %v", i, err)
return fmt.Errorf("index %d: %v", list.Len(), err)
}
list.Append(v)
}
Expand All @@ -414,28 +409,14 @@ func setField(msg protoreflect.Message, fdesc protoreflect.FieldDescriptor, valu
return fmt.Errorf("in map field %s: expected mappable type, but got %s", fdesc.Name(), value.Type())
}

iter := mapping.Iterate()
defer iter.Done()

// Each value is converted using toProto as usual, passing the key/value
// field descriptors to check their types.
mutMap := msg.Mutable(fdesc).Map()
var k starlark.Value
for iter.Next(&k) {
for k, v := range starlark.MapRange(mapping) {
kproto, err := toProto(fdesc.MapKey(), k)
if err != nil {
return fmt.Errorf("in key of map field %s: %w", fdesc.Name(), err)
}

// `found` is discarded, as the presence of the key in the
// iterator guarantees the presence of some value (even if it is
// starlark.None). Mismatching values will be caught in toProto
// below.
v, _, err := mapping.Get(k)
if err != nil {
return fmt.Errorf("in map field %s, at key %s: %w", fdesc.Name(), k.String(), err)
}

vproto, err := toProto(fdesc.MapValue(), v)
if err != nil {
return fmt.Errorf("in map field %s, at key %s: %w", fdesc.Name(), k.String(), err)
Expand Down Expand Up @@ -1219,11 +1200,10 @@ func enumValueOf(enum protoreflect.EnumDescriptor, x starlark.Value) (protorefle
//
// An EnumValueDescriptor has the following fields:
//
// index -- int, index of this value within the enum sequence
// name -- string, name of this enum value
// number -- int, numeric value of this enum value
// type -- EnumDescriptor, the enum type to which this value belongs
//
// index -- int, index of this value within the enum sequence
// name -- string, name of this enum value
// number -- int, numeric value of this enum value
// type -- EnumDescriptor, the enum type to which this value belongs
type EnumValueDescriptor struct {
Desc protoreflect.EnumValueDescriptor
}
Expand Down
60 changes: 18 additions & 42 deletions starlark/library.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,7 @@ func all(thread *Thread, _ *Builtin, args Tuple, kwargs []Tuple) (Value, error)
if err := UnpackPositionalArgs("all", args, kwargs, 1, &iterable); err != nil {
return nil, err
}
iter := iterable.Iterate()
defer iter.Done()
var x Value
for iter.Next(&x) {
for x := range Range(iterable) {
if !x.Truth() {
return False, nil
}
Expand All @@ -215,10 +212,7 @@ func any(thread *Thread, _ *Builtin, args Tuple, kwargs []Tuple) (Value, error)
if err := UnpackPositionalArgs("any", args, kwargs, 1, &iterable); err != nil {
return nil, err
}
iter := iterable.Iterate()
defer iter.Done()
var x Value
for iter.Next(&x) {
for x := range Range(iterable) {
if x.Truth() {
return True, nil
}
Expand Down Expand Up @@ -256,13 +250,10 @@ func bytes_(thread *Thread, _ *Builtin, args Tuple, kwargs []Tuple) (Value, erro
// common case: known length
buf.Grow(n)
}
iter := x.Iterate()
defer iter.Done()
var elem Value
var b byte
for i := 0; iter.Next(&elem); i++ {
for elem := range Range(x) {
if err := AsInt(elem, &b); err != nil {
return nil, fmt.Errorf("bytes: at index %d, %s", i, err)
return nil, fmt.Errorf("bytes: at index %d, %s", buf.Len(), err)
}
buf.WriteByte(b)
}
Expand Down Expand Up @@ -683,13 +674,10 @@ func list(thread *Thread, _ *Builtin, args Tuple, kwargs []Tuple) (Value, error)
}
var elems []Value
if iterable != nil {
iter := iterable.Iterate()
defer iter.Done()
if n := Len(iterable); n > 0 {
elems = make([]Value, 0, n) // preallocate if length known
}
var x Value
for iter.Next(&x) {
for x := range Range(iterable) {
elems = append(elems, x)
}
}
Expand Down Expand Up @@ -969,14 +957,11 @@ func reversed(thread *Thread, _ *Builtin, args Tuple, kwargs []Tuple) (Value, er
if err := UnpackPositionalArgs("reversed", args, kwargs, 1, &iterable); err != nil {
return nil, err
}
iter := iterable.Iterate()
defer iter.Done()
var elems []Value
if n := Len(args[0]); n >= 0 {
elems = make([]Value, 0, n) // preallocate if length known
}
var x Value
for iter.Next(&x) {
for x := range Range(iterable) {
elems = append(elems, x)
}
n := len(elems)
Expand All @@ -994,10 +979,7 @@ func set(thread *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error)
}
set := new(Set)
if iterable != nil {
iter := iterable.Iterate()
defer iter.Done()
var x Value
for iter.Next(&x) {
for x := range Range(iterable) {
if err := set.Insert(x); err != nil {
return nil, nameErr(b, err)
}
Expand All @@ -1020,14 +1002,11 @@ func sorted(thread *Thread, _ *Builtin, args Tuple, kwargs []Tuple) (Value, erro
return nil, err
}

iter := iterable.Iterate()
defer iter.Done()
var values []Value
if n := Len(iterable); n > 0 {
values = make(Tuple, 0, n) // preallocate if length is known
}
var x Value
for iter.Next(&x) {
for x := range Range(iterable) {
values = append(values, x)
}

Expand Down Expand Up @@ -1120,14 +1099,11 @@ func tuple(thread *Thread, _ *Builtin, args Tuple, kwargs []Tuple) (Value, error
if len(args) == 0 {
return Tuple(nil), nil
}
iter := iterable.Iterate()
defer iter.Done()
var elems Tuple
if n := Len(iterable); n > 0 {
elems = make(Tuple, 0, n) // preallocate if length is known
}
var x Value
for iter.Next(&x) {
for x := range Range(iterable) {
elems = append(elems, x)
}
return elems, nil
Expand Down Expand Up @@ -1847,11 +1823,9 @@ func string_join(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, erro
if err := UnpackPositionalArgs(b.Name(), args, kwargs, 1, &iterable); err != nil {
return nil, err
}
iter := iterable.Iterate()
defer iter.Done()
buf := new(strings.Builder)
var x Value
for i := 0; iter.Next(&x); i++ {
i := 0
for x := range Range(iterable) {
if i > 0 {
buf.WriteString(recv)
}
Expand All @@ -1860,6 +1834,7 @@ func string_join(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, erro
return nil, fmt.Errorf("join: in list, want string, got %s", x.Type())
}
buf.WriteString(s)
i++
}
return String(buf.String()), nil
}
Expand Down Expand Up @@ -2396,13 +2371,13 @@ func updateDict(dict *Dict, updates Tuple, kwargs []Tuple) error {
}
default:
// all other sequences
iter := Iterate(updates)
if iter == nil {
iterable, ok := updates.(Iterable)
if !ok {
return fmt.Errorf("got %s, want iterable", updates.Type())
}
defer iter.Done()
var pair Value
for i := 0; iter.Next(&pair); i++ {
i := 0
for pair := range Range(iterable) {
// TODO(adonovan): opt: specialize common case: pair is 2-tuple.
iter2 := Iterate(pair)
if iter2 == nil {
return fmt.Errorf("dictionary update sequence element #%d is not iterable (%s)", i, pair.Type())
Expand All @@ -2421,6 +2396,7 @@ func updateDict(dict *Dict, updates Tuple, kwargs []Tuple) error {
if err := dict.SetKey(k, v); err != nil {
return err
}
i++
}
}
}
Expand Down
Loading

0 comments on commit c3c9960

Please sign in to comment.