diff --git a/README.md b/README.md index 380bb36bb..5a5df6370 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ representations. (There is an example of this below.) Spec: https://github.com/mojombo/toml Compatible with TOML version -[v0.2.0](https://github.com/mojombo/toml/blob/master/versions/toml-v0.2.0.md) +[v0.2.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.2.0.md) Documentation: http://godoc.org/github.com/BurntSushi/toml @@ -111,7 +111,7 @@ type songs struct { Song []song } var favorites songs -if _, err := Decode(blob, &favorites); err != nil { +if _, err := toml.Decode(blob, &favorites); err != nil { log.Fatal(err) } diff --git a/_examples/example.go b/_examples/example.go index c81d25a52..79f31f275 100644 --- a/_examples/example.go +++ b/_examples/example.go @@ -48,9 +48,11 @@ func main() { fmt.Printf("Title: %s\n", config.Title) fmt.Printf("Owner: %s (%s, %s), Born: %s\n", - config.Owner.Name, config.Owner.Org, config.Owner.Bio, config.Owner.DOB) + config.Owner.Name, config.Owner.Org, config.Owner.Bio, + config.Owner.DOB) fmt.Printf("Database: %s %v (Max conn. %d), Enabled? %v\n", - config.DB.Server, config.DB.Ports, config.DB.ConnMax, config.DB.Enabled) + config.DB.Server, config.DB.Ports, config.DB.ConnMax, + config.DB.Enabled) for serverName, server := range config.Servers { fmt.Printf("Server: %s (%s, %s)\n", serverName, server.IP, server.DC) } diff --git a/decode.go b/decode.go index b6d75d042..c26b00c01 100644 --- a/decode.go +++ b/decode.go @@ -12,6 +12,18 @@ import ( var e = fmt.Errorf +// Unmarshaler is the interface implemented by objects that can unmarshal a +// TOML description of themselves. +type Unmarshaler interface { + UnmarshalTOML(interface{}) error +} + +// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`. +func Unmarshal(p []byte, v interface{}) error { + _, err := Decode(string(p), v) + return err +} + // Primitive is a TOML value that hasn't been decoded into a Go value. 
// When using the various `Decode*` functions, the type `Primitive` may // be given to any value, and its decoding will be delayed. @@ -128,6 +140,7 @@ func DecodeReader(r io.Reader, v interface{}) (MetaData, error) { // Any type mismatch produces an error. Finding a type that we don't know // how to handle produces an unsupported type error. func (md *MetaData) unify(data interface{}, rv reflect.Value) error { + // Special case. Look for a `Primitive` value. if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() { // Save the undecoded data and the key context into the primitive @@ -141,6 +154,13 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error { return nil } + // Special case. Unmarshaler Interface support. + if rv.CanAddr() { + if v, ok := rv.Addr().Interface().(Unmarshaler); ok { + return v.UnmarshalTOML(data) + } + } + // Special case. Handle time.Time values specifically. // TODO: Remove this code when we decide to drop support for Go 1.1. // This isn't necessary in Go 1.2 because time.Time satisfies the encoding @@ -205,6 +225,9 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error { func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { tmap, ok := mapping.(map[string]interface{}) if !ok { + if mapping == nil { + return nil + } return mismatch(rv, "map", mapping) } @@ -247,6 +270,9 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { tmap, ok := mapping.(map[string]interface{}) if !ok { + if tmap == nil { + return nil + } return badtype("map", mapping) } if rv.IsNil() { @@ -272,6 +298,9 @@ func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { datav := reflect.ValueOf(data) if datav.Kind() != reflect.Slice { + if !datav.IsValid() { + return nil + } return badtype("slice", data) } sliceLen := 
datav.Len() @@ -285,12 +314,16 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error { datav := reflect.ValueOf(data) if datav.Kind() != reflect.Slice { + if !datav.IsValid() { + return nil + } return badtype("slice", data) } - sliceLen := datav.Len() - if rv.IsNil() { - rv.Set(reflect.MakeSlice(rv.Type(), sliceLen, sliceLen)) + n := datav.Len() + if rv.IsNil() || rv.Cap() < n { + rv.Set(reflect.MakeSlice(rv.Type(), n, n)) } + rv.SetLen(n) return md.unifySliceArray(datav, rv) } diff --git a/decode_meta.go b/decode_meta.go index c8114453b..ef6f545fa 100644 --- a/decode_meta.go +++ b/decode_meta.go @@ -59,6 +59,29 @@ func (k Key) String() string { return strings.Join(k, ".") } +func (k Key) maybeQuotedAll() string { + var ss []string + for i := range k { + ss = append(ss, k.maybeQuoted(i)) + } + return strings.Join(ss, ".") +} + +func (k Key) maybeQuoted(i int) string { + quote := false + for _, c := range k[i] { + if !isBareKeyChar(c) { + quote = true + break + } + } + if quote { + return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\"" + } else { + return k[i] + } +} + func (k Key) add(piece string) Key { newKey := make(Key, len(k)+1) copy(newKey, k) diff --git a/decode_test.go b/decode_test.go index b940333dc..213e70dca 100644 --- a/decode_test.go +++ b/decode_test.go @@ -17,7 +17,7 @@ func TestDecodeSimple(t *testing.T) { age = 250 andrew = "gallant" kait = "brady" -now = 1987-07-05T05:45:00Z +now = 1987-07-05T05:45:00Z yesOrNo = true pi = 3.14 colors = [ @@ -67,7 +67,7 @@ cauchy = "cat 2" {"cyan", "magenta", "yellow", "black"}, }, My: map[string]cats{ - "Cats": cats{Plato: "cat 1", Cauchy: "cat 2"}, + "Cats": {Plato: "cat 1", Cauchy: "cat 2"}, }, } if !reflect.DeepEqual(val, answer) { @@ -119,6 +119,23 @@ func TestDecodeEmbedded(t *testing.T) { } } +func TestDecodeIgnoredFields(t *testing.T) { + type simple struct { + Number int `toml:"-"` + } + const input = ` 
+Number = 123 +- = 234 +` + var s simple + if _, err := Decode(input, &s); err != nil { + t.Fatal(err) + } + if s.Number != 0 { + t.Errorf("got: %d; want 0", s.Number) + } +} + func TestTableArrays(t *testing.T) { var tomlTableArrays = ` [[albums]] @@ -132,7 +149,7 @@ name = "Born to Run" [[albums]] name = "Born in the USA" - + [[albums.songs]] name = "Glory Days" @@ -285,6 +302,43 @@ Description = "da base" } } +func TestDecodeBadTimestamp(t *testing.T) { + var x struct { + T time.Time + } + for _, s := range []string{ + "T = 123", "T = 2006-01-50T00:00:00Z", "T = 2006-01-30T00:00:00", + } { + if _, err := Decode(s, &x); err == nil { + t.Errorf("Expected invalid DateTime error for %q", s) + } + } +} + +func TestDecodeMultilineStrings(t *testing.T) { + var x struct { + S string + } + const s0 = `s = """ +a b \n c +d e f +"""` + if _, err := Decode(s0, &x); err != nil { + t.Fatal(err) + } + if want := "a b \n c\nd e f\n"; x.S != want { + t.Errorf("got: %q; want: %q", x.S, want) + } + const s1 = `s = """a b c\ +"""` + if _, err := Decode(s1, &x); err != nil { + t.Fatal(err) + } + if want := "a b c"; x.S != want { + t.Errorf("got: %q; want: %q", x.S, want) + } +} + type sphere struct { Center [3]float64 Radius float64 @@ -349,6 +403,213 @@ func TestDecodeSizedInts(t *testing.T) { } } +func TestUnmarshaler(t *testing.T) { + + var tomlBlob = ` +[dishes.hamboogie] +name = "Hamboogie with fries" +price = 10.99 + +[[dishes.hamboogie.ingredients]] +name = "Bread Bun" + +[[dishes.hamboogie.ingredients]] +name = "Lettuce" + +[[dishes.hamboogie.ingredients]] +name = "Real Beef Patty" + +[[dishes.hamboogie.ingredients]] +name = "Tomato" + +[dishes.eggsalad] +name = "Egg Salad with rice" +price = 3.99 + +[[dishes.eggsalad.ingredients]] +name = "Egg" + +[[dishes.eggsalad.ingredients]] +name = "Mayo" + +[[dishes.eggsalad.ingredients]] +name = "Rice" +` + m := &menu{} + if _, err := Decode(tomlBlob, m); err != nil { + log.Fatal(err) + } + + if len(m.Dishes) != 2 { + t.Log("two 
dishes should be loaded with UnmarshalTOML()") + t.Errorf("expected %d but got %d", 2, len(m.Dishes)) + } + + eggSalad := m.Dishes["eggsalad"] + if _, ok := interface{}(eggSalad).(dish); !ok { + t.Errorf("expected a dish") + } + + if eggSalad.Name != "Egg Salad with rice" { + t.Errorf("expected the dish to be named 'Egg Salad with rice'") + } + + if len(eggSalad.Ingredients) != 3 { + t.Log("dish should be loaded with UnmarshalTOML()") + t.Errorf("expected %d but got %d", 3, len(eggSalad.Ingredients)) + } + + found := false + for _, i := range eggSalad.Ingredients { + if i.Name == "Rice" { + found = true + break + } + } + if !found { + t.Error("Rice was not loaded in UnmarshalTOML()") + } + + // test on a value - must be passed as * + o := menu{} + if _, err := Decode(tomlBlob, &o); err != nil { + log.Fatal(err) + } + +} + +type menu struct { + Dishes map[string]dish +} + +func (m *menu) UnmarshalTOML(p interface{}) error { + m.Dishes = make(map[string]dish) + data, _ := p.(map[string]interface{}) + dishes := data["dishes"].(map[string]interface{}) + for n, v := range dishes { + if d, ok := v.(map[string]interface{}); ok { + nd := dish{} + nd.UnmarshalTOML(d) + m.Dishes[n] = nd + } else { + return fmt.Errorf("not a dish") + } + } + return nil +} + +type dish struct { + Name string + Price float32 + Ingredients []ingredient +} + +func (d *dish) UnmarshalTOML(p interface{}) error { + data, _ := p.(map[string]interface{}) + d.Name, _ = data["name"].(string) + d.Price, _ = data["price"].(float32) + ingredients, _ := data["ingredients"].([]map[string]interface{}) + for _, e := range ingredients { + n, _ := interface{}(e).(map[string]interface{}) + name, _ := n["name"].(string) + i := ingredient{name} + d.Ingredients = append(d.Ingredients, i) + } + return nil +} + +type ingredient struct { + Name string +} + +func TestDecodeSlices(t *testing.T) { + type T struct { + S []string + } + for i, tt := range []struct { + v T + input string + want T + }{ + {T{}, "", T{}}, + 
{T{[]string{}}, "", T{[]string{}}}, + {T{[]string{"a", "b"}}, "", T{[]string{"a", "b"}}}, + {T{}, "S = []", T{[]string{}}}, + {T{[]string{}}, "S = []", T{[]string{}}}, + {T{[]string{"a", "b"}}, "S = []", T{[]string{}}}, + {T{}, `S = ["x"]`, T{[]string{"x"}}}, + {T{[]string{}}, `S = ["x"]`, T{[]string{"x"}}}, + {T{[]string{"a", "b"}}, `S = ["x"]`, T{[]string{"x"}}}, + } { + if _, err := Decode(tt.input, &tt.v); err != nil { + t.Errorf("[%d] %s", i, err) + continue + } + if !reflect.DeepEqual(tt.v, tt.want) { + t.Errorf("[%d] got %#v; want %#v", i, tt.v, tt.want) + } + } +} + +func TestDecodePrimitive(t *testing.T) { + type S struct { + P Primitive + } + type T struct { + S []int + } + slicep := func(s []int) *[]int { return &s } + arrayp := func(a [2]int) *[2]int { return &a } + mapp := func(m map[string]int) *map[string]int { return &m } + for i, tt := range []struct { + v interface{} + input string + want interface{} + }{ + // slices + {slicep(nil), "", slicep(nil)}, + {slicep([]int{}), "", slicep([]int{})}, + {slicep([]int{1, 2, 3}), "", slicep([]int{1, 2, 3})}, + {slicep(nil), "P = [1,2]", slicep([]int{1, 2})}, + {slicep([]int{}), "P = [1,2]", slicep([]int{1, 2})}, + {slicep([]int{1, 2, 3}), "P = [1,2]", slicep([]int{1, 2})}, + + // arrays + {arrayp([2]int{2, 3}), "", arrayp([2]int{2, 3})}, + {arrayp([2]int{2, 3}), "P = [3,4]", arrayp([2]int{3, 4})}, + + // maps + {mapp(nil), "", mapp(nil)}, + {mapp(map[string]int{}), "", mapp(map[string]int{})}, + {mapp(map[string]int{"a": 1}), "", mapp(map[string]int{"a": 1})}, + {mapp(nil), "[P]\na = 2", mapp(map[string]int{"a": 2})}, + {mapp(map[string]int{}), "[P]\na = 2", mapp(map[string]int{"a": 2})}, + {mapp(map[string]int{"a": 1, "b": 3}), "[P]\na = 2", mapp(map[string]int{"a": 2, "b": 3})}, + + // structs + {&T{nil}, "[P]", &T{nil}}, + {&T{[]int{}}, "[P]", &T{[]int{}}}, + {&T{[]int{1, 2, 3}}, "[P]", &T{[]int{1, 2, 3}}}, + {&T{nil}, "[P]\nS = [1,2]", &T{[]int{1, 2}}}, + {&T{[]int{}}, "[P]\nS = [1,2]", &T{[]int{1, 2}}}, 
+ {&T{[]int{1, 2, 3}}, "[P]\nS = [1,2]", &T{[]int{1, 2}}}, + } { + var s S + md, err := Decode(tt.input, &s) + if err != nil { + t.Errorf("[%d] Decode error: %s", i, err) + continue + } + if err := md.PrimitiveDecode(s.P, tt.v); err != nil { + t.Errorf("[%d] PrimitiveDecode error: %s", i, err) + continue + } + if !reflect.DeepEqual(tt.v, tt.want) { + t.Errorf("[%d] got %#v; want %#v", i, tt.v, tt.want) + } + } +} + func ExampleMetaData_PrimitiveDecode() { var md MetaData var err error @@ -360,7 +621,7 @@ ranking = ["Springsteen", "J Geils"] started = 1973 albums = ["Greetings", "WIESS", "Born to Run", "Darkness"] -[bands.J Geils] +[bands."J Geils"] started = 1970 albums = ["The J. Geils Band", "Full House", "Blow Your Face Out"] ` @@ -434,7 +695,7 @@ ip = "10.0.0.2" } type server struct { - IP string `toml:"ip"` + IP string `toml:"ip,omitempty"` Config serverConfig `toml:"config"` } @@ -538,3 +799,294 @@ key3 = "value3" // Output: // Undecoded keys: ["key2"] } + +// Example UnmarshalTOML shows how to implement a struct type that knows how to +// unmarshal itself. The struct must take full responsibility for mapping the +// values passed into the struct. The method may be used with interfaces in a +// struct in cases where the actual type is not known until the data is +// examined. +func Example_unmarshalTOML() { + + var blob = ` +[[parts]] +type = "valve" +id = "valve-1" +size = 1.2 +rating = 4 + +[[parts]] +type = "valve" +id = "valve-2" +size = 2.1 +rating = 5 + +[[parts]] +type = "pipe" +id = "pipe-1" +length = 2.1 +diameter = 12 + +[[parts]] +type = "cable" +id = "cable-1" +length = 12 +rating = 3.1 +` + o := &order{} + err := Unmarshal([]byte(blob), o) + if err != nil { + log.Fatal(err) + } + + fmt.Println(len(o.parts)) + + for _, part := range o.parts { + fmt.Println(part.Name()) + } + + // Code to implement UnmarshalTOML. 
+ + // type order struct { + // // NOTE `order.parts` is a private slice of type `part` which is an + // // interface and may only be loaded from toml using the + // // UnmarshalTOML() method of the Unmarshaler interface. + // parts parts + // } + + // func (o *order) UnmarshalTOML(data interface{}) error { + + // // NOTE the example below contains detailed type casting to show how + // // the 'data' is retrieved. In operational use, a type cast wrapper + // // may be preferred e.g. + // // + // // func AsMap(v interface{}) (map[string]interface{}, error) { + // // return v.(map[string]interface{}) + // // } + // // + // // resulting in: + // // d, _ := AsMap(data) + // // + + // d, _ := data.(map[string]interface{}) + // parts, _ := d["parts"].([]map[string]interface{}) + + // for _, p := range parts { + + // typ, _ := p["type"].(string) + // id, _ := p["id"].(string) + + // // detect the type of part and handle each case + // switch p["type"] { + // case "valve": + + // size := float32(p["size"].(float64)) + // rating := int(p["rating"].(int64)) + + // valve := &valve{ + // Type: typ, + // ID: id, + // Size: size, + // Rating: rating, + // } + + // o.parts = append(o.parts, valve) + + // case "pipe": + + // length := float32(p["length"].(float64)) + // diameter := int(p["diameter"].(int64)) + + // pipe := &pipe{ + // Type: typ, + // ID: id, + // Length: length, + // Diameter: diameter, + // } + + // o.parts = append(o.parts, pipe) + + // case "cable": + + // length := int(p["length"].(int64)) + // rating := float32(p["rating"].(float64)) + + // cable := &cable{ + // Type: typ, + // ID: id, + // Length: length, + // Rating: rating, + // } + + // o.parts = append(o.parts, cable) + + // } + // } + + // return nil + // } + + // type parts []part + + // type part interface { + // Name() string + // } + + // type valve struct { + // Type string + // ID string + // Size float32 + // Rating int + // } + + // func (v *valve) Name() string { + // return fmt.Sprintf("VALVE: 
%s", v.ID) + // } + + // type pipe struct { + // Type string + // ID string + // Length float32 + // Diameter int + // } + + // func (p *pipe) Name() string { + // return fmt.Sprintf("PIPE: %s", p.ID) + // } + + // type cable struct { + // Type string + // ID string + // Length int + // Rating float32 + // } + + // func (c *cable) Name() string { + // return fmt.Sprintf("CABLE: %s", c.ID) + // } + + // Output: + // 4 + // VALVE: valve-1 + // VALVE: valve-2 + // PIPE: pipe-1 + // CABLE: cable-1 + +} + +type order struct { + // NOTE `order.parts` is a private slice of type `part` which is an + // interface and may only be loaded from toml using the UnmarshalTOML() + // method of the Unmarshaler interface. + parts parts +} + +func (o *order) UnmarshalTOML(data interface{}) error { + + // NOTE the example below contains detailed type casting to show how + // the 'data' is retrieved. In operational use, a type cast wrapper + // may be preferred e.g. + // + // func AsMap(v interface{}) (map[string]interface{}, error) { + // return v.(map[string]interface{}) + // } + // + // resulting in: + // d, _ := AsMap(data) + // + + d, _ := data.(map[string]interface{}) + parts, _ := d["parts"].([]map[string]interface{}) + + for _, p := range parts { + + typ, _ := p["type"].(string) + id, _ := p["id"].(string) + + // detect the type of part and handle each case + switch p["type"] { + case "valve": + + size := float32(p["size"].(float64)) + rating := int(p["rating"].(int64)) + + valve := &valve{ + Type: typ, + ID: id, + Size: size, + Rating: rating, + } + + o.parts = append(o.parts, valve) + + case "pipe": + + length := float32(p["length"].(float64)) + diameter := int(p["diameter"].(int64)) + + pipe := &pipe{ + Type: typ, + ID: id, + Length: length, + Diameter: diameter, + } + + o.parts = append(o.parts, pipe) + + case "cable": + + length := int(p["length"].(int64)) + rating := float32(p["rating"].(float64)) + + cable := &cable{ + Type: typ, + ID: id, + Length: length, + Rating: 
rating, + } + + o.parts = append(o.parts, cable) + + } + } + + return nil +} + +type parts []part + +type part interface { + Name() string +} + +type valve struct { + Type string + ID string + Size float32 + Rating int +} + +func (v *valve) Name() string { + return fmt.Sprintf("VALVE: %s", v.ID) +} + +type pipe struct { + Type string + ID string + Length float32 + Diameter int +} + +func (p *pipe) Name() string { + return fmt.Sprintf("PIPE: %s", p.ID) +} + +type cable struct { + Type string + ID string + Length int + Rating float32 +} + +func (c *cable) Name() string { + return fmt.Sprintf("CABLE: %s", c.ID) +} diff --git a/encode.go b/encode.go index 361871347..4e4c97aed 100644 --- a/encode.go +++ b/encode.go @@ -118,7 +118,8 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) { k := rv.Kind() switch k { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String, reflect.Bool: @@ -173,7 +174,8 @@ func (enc *Encoder) eElement(rv reflect.Value) { switch rv.Kind() { case reflect.Bool: enc.wf(strconv.FormatBool(rv.Bool())) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64: enc.wf(strconv.FormatInt(rv.Int(), 10)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: @@ -223,28 +225,28 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { if len(key) == 0 { encPanic(errNoKey) } - panicIfInvalidKey(key, true) for i := 0; i < rv.Len(); i++ { trv := rv.Index(i) if isNil(trv) { continue } + panicIfInvalidKey(key) enc.newline() - enc.wf("%s[[%s]]", enc.indentStr(key), key.String()) + enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll()) enc.newline() enc.eMapOrStruct(key, trv) } } func (enc 
*Encoder) eTable(key Key, rv reflect.Value) { + panicIfInvalidKey(key) if len(key) == 1 { // Output an extra new line between top-level tables. // (The newline isn't written if nothing else has been written though.) enc.newline() } if len(key) > 0 { - panicIfInvalidKey(key, true) - enc.wf("%s[%s]", enc.indentStr(key), key.String()) + enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll()) enc.newline() } enc.eMapOrStruct(key, rv) @@ -304,19 +306,30 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) { addFields = func(rt reflect.Type, rv reflect.Value, start []int) { for i := 0; i < rt.NumField(); i++ { f := rt.Field(i) - // skip unexporded fields - if f.PkgPath != "" { + // skip unexported fields + if f.PkgPath != "" && !f.Anonymous { continue } frv := rv.Field(i) if f.Anonymous { - frv := eindirect(frv) - t := frv.Type() - if t.Kind() != reflect.Struct { - encPanic(errAnonNonStruct) + t := f.Type + switch t.Kind() { + case reflect.Struct: + addFields(t, frv, f.Index) + continue + case reflect.Ptr: + if t.Elem().Kind() == reflect.Struct { + if !frv.IsNil() { + addFields(t.Elem(), frv.Elem(), f.Index) + } + continue + } + // Fall through to the normal field encoding logic below + // for non-struct anonymous fields. 
} - addFields(t, frv, f.Index) - } else if typeIsHash(tomlTypeOfGo(frv)) { + } + + if typeIsHash(tomlTypeOfGo(frv)) { fieldsSub = append(fieldsSub, append(start, f.Index...)) } else { fieldsDirect = append(fieldsDirect, append(start, f.Index...)) @@ -334,13 +347,20 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) { continue } - keyName := sft.Tag.Get("toml") - if keyName == "-" { + tag := sft.Tag.Get("toml") + if tag == "-" { continue } + keyName, opts := getOptions(tag) if keyName == "" { keyName = sft.Name } + if _, ok := opts["omitempty"]; ok && isEmpty(sf) { + continue + } else if _, ok := opts["omitzero"]; ok && isZero(sf) { + continue + } + enc.encode(key.add(keyName), sf) } } @@ -348,10 +368,10 @@ func (enc *Encoder) eStruct(key Key, rv reflect.Value) { writeFields(fieldsSub) } -// tomlTypeName returns the TOML type name of the Go value's type. It is used to -// determine whether the types of array elements are mixed (which is forbidden). -// If the Go value is nil, then it is illegal for it to be an array element, and -// valueIsNil is returned as true. +// tomlTypeName returns the TOML type name of the Go value's type. It is +// used to determine whether the types of array elements are mixed (which is +// forbidden). If the Go value is nil, then it is illegal for it to be an array +// element, and valueIsNil is returned as true. // Returns the TOML type of a Go value. The type may be `nil`, which means // no concrete TOML type could be found. 
@@ -362,7 +382,8 @@ func tomlTypeOfGo(rv reflect.Value) tomlType { switch rv.Kind() { case reflect.Bool: return tomlBool - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, + reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return tomlInteger @@ -430,6 +451,41 @@ func tomlArrayType(rv reflect.Value) tomlType { return firstType } +func getOptions(keyName string) (string, map[string]struct{}) { + opts := make(map[string]struct{}) + ss := strings.Split(keyName, ",") + name := ss[0] + if len(ss) > 1 { + for _, opt := range ss { + opts[opt] = struct{}{} + } + } + + return name, opts +} + +func isZero(rv reflect.Value) bool { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return rv.Uint() == 0 + case reflect.Float32, reflect.Float64: + return rv.Float() == 0.0 + } + return false +} + +func isEmpty(rv reflect.Value) bool { + switch rv.Kind() { + case reflect.Array, reflect.Slice, reflect.Map, reflect.String: + return rv.Len() == 0 + case reflect.Bool: + return !rv.Bool() + } + return false +} + func (enc *Encoder) newline() { if enc.hasWritten { enc.wf("\n") @@ -440,8 +496,8 @@ func (enc *Encoder) keyEqElement(key Key, val reflect.Value) { if len(key) == 0 { encPanic(errNoKey) } - panicIfInvalidKey(key, false) - enc.wf("%s%s = ", enc.indentStr(key), key[len(key)-1]) + panicIfInvalidKey(key) + enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1)) enc.eElement(val) enc.newline() } @@ -479,37 +535,15 @@ func isNil(rv reflect.Value) bool { } } -func panicIfInvalidKey(key Key, hash bool) { - if hash { - for _, k := range key { - if !isValidTableName(k) { - encPanic(e("Key '%s' is not a valid table name. 
Table names "+ - "cannot contain '[', ']' or '.'.", key.String())) - } - } - } else { - if !isValidKeyName(key[len(key)-1]) { - encPanic(e("Key '%s' is not a name. Key names "+ - "cannot contain whitespace.", key.String())) - } - } -} - -func isValidTableName(s string) bool { - if len(s) == 0 { - return false - } - for _, r := range s { - if r == '[' || r == ']' || r == '.' { - return false +func panicIfInvalidKey(key Key) { + for _, k := range key { + if len(k) == 0 { + encPanic(e("Key '%s' is not a valid table name. Key names "+ + "cannot be empty.", key.maybeQuotedAll())) } } - return true } func isValidKeyName(s string) bool { - if len(s) == 0 { - return false - } - return true + return len(s) != 0 } diff --git a/encode_test.go b/encode_test.go index 74a5ee5d2..ef7acdd74 100644 --- a/encode_test.go +++ b/encode_test.go @@ -336,6 +336,10 @@ ArrayOfMixedSlices = [[1, 2], ["a", "b"]] }{struct{ *Embedded }{&Embedded{1}}}, wantOutput: "[_struct]\n _int = 1\n", }, + "embedded non-struct": { + input: struct{ NonStruct }{5}, + wantOutput: "NonStruct = 5\n", + }, "array of tables": { input: struct { Structs []*struct{ Int int } `toml:"struct"` @@ -349,7 +353,7 @@ ArrayOfMixedSlices = [[1, 2], ["a", "b"]] "map": map[string]interface{}{ "zero": 5, "arr": []map[string]int{ - map[string]int{ + { "friend": 5, }, }, @@ -373,10 +377,6 @@ ArrayOfMixedSlices = [[1, 2], ["a", "b"]] input: map[int]string{1: ""}, wantError: errNonString, }, - "(error) anonymous non-struct": { - input: struct{ NonStruct }{5}, - wantError: errAnonNonStruct, - }, "(error) empty key name": { input: map[string]int{"": 1}, wantError: errAnything, @@ -455,6 +455,90 @@ func TestEncodeArrayHashWithNormalHashOrder(t *testing.T) { encodeExpected(t, "array hash with normal hash order", val, expected, nil) } +func TestEncodeWithOmitEmpty(t *testing.T) { + type simple struct { + Bool bool `toml:"bool,omitempty"` + String string `toml:"string,omitempty"` + Array [0]byte `toml:"array,omitempty"` + Slice []int 
`toml:"slice,omitempty"` + Map map[string]string `toml:"map,omitempty"` + } + + var v simple + encodeExpected(t, "fields with omitempty are omitted when empty", v, "", nil) + v = simple{ + Bool: true, + String: " ", + Slice: []int{2, 3, 4}, + Map: map[string]string{"foo": "bar"}, + } + expected := `bool = true +string = " " +slice = [2, 3, 4] + +[map] + foo = "bar" +` + encodeExpected(t, "fields with omitempty are not omitted when non-empty", + v, expected, nil) +} + +func TestEncodeWithOmitZero(t *testing.T) { + type simple struct { + Number int `toml:"number,omitzero"` + Real float64 `toml:"real,omitzero"` + Unsigned uint `toml:"unsigned,omitzero"` + } + + value := simple{0, 0.0, uint(0)} + expected := "" + + encodeExpected(t, "simple with omitzero, all zero", value, expected, nil) + + value.Number = 10 + value.Real = 20 + value.Unsigned = 5 + expected = `number = 10 +real = 20.0 +unsigned = 5 +` + encodeExpected(t, "simple with omitzero, non-zero", value, expected, nil) +} + +func TestEncodeOmitemptyWithEmptyName(t *testing.T) { + type simple struct { + S []int `toml:",omitempty"` + } + v := simple{[]int{1, 2, 3}} + expected := "S = [1, 2, 3]\n" + encodeExpected(t, "simple with omitempty, no name, non-empty field", + v, expected, nil) +} + +func TestEncodeAnonymousStructPointerField(t *testing.T) { + type Sub struct{} + type simple struct { + *Sub + } + + value := simple{} + expected := "" + encodeExpected(t, "nil anonymous struct pointer field", value, expected, nil) + + value = simple{Sub: &Sub{}} + expected = "" + encodeExpected(t, "non-nil anonymous struct pointer field", value, expected, nil) +} + +func TestEncodeIgnoredFields(t *testing.T) { + type simple struct { + Number int `toml:"-"` + } + value := simple{} + expected := "" + encodeExpected(t, "ignored field", value, expected, nil) +} + func encodeExpected( t *testing.T, label string, val interface{}, wantStr string, wantErr error, ) { diff --git a/encoding_types.go b/encoding_types.go index 
140c44c11..d36e1dd60 100644 --- a/encoding_types.go +++ b/encoding_types.go @@ -14,6 +14,6 @@ import ( // so that Go 1.1 can be supported. type TextMarshaler encoding.TextMarshaler -// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined here -// so that Go 1.1 can be supported. +// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined +// here so that Go 1.1 can be supported. type TextUnmarshaler encoding.TextUnmarshaler diff --git a/encoding_types_1.1.go b/encoding_types_1.1.go index fb285e7f5..e8d503d04 100644 --- a/encoding_types_1.1.go +++ b/encoding_types_1.1.go @@ -11,8 +11,8 @@ type TextMarshaler interface { MarshalText() (text []byte, err error) } -// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined here -// so that Go 1.1 can be supported. +// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined +// here so that Go 1.1 can be supported. type TextUnmarshaler interface { UnmarshalText(text []byte) error } diff --git a/lex.go b/lex.go index 3821fa271..9b20b3a81 100644 --- a/lex.go +++ b/lex.go @@ -14,6 +14,9 @@ const ( itemEOF itemText itemString + itemRawString + itemMultilineString + itemRawMultilineString itemBool itemInteger itemFloat @@ -42,6 +45,8 @@ const ( commentStart = '#' stringStart = '"' stringEnd = '"' + rawStringStart = '\'' + rawStringEnd = '\'' ) type stateFn func(lx *lexer) stateFn @@ -256,38 +261,52 @@ func lexArrayTableEnd(lx *lexer) stateFn { } func lexTableNameStart(lx *lexer) stateFn { - switch lx.next() { - case tableEnd, eof: - return lx.errorf("Unexpected end of table. (Tables cannot " + + switch r := lx.peek(); { + case r == tableEnd || r == eof: + return lx.errorf("Unexpected end of table name. (Table names cannot " + "be empty.)") - case tableSep: - return lx.errorf("Unexpected table separator. (Tables cannot " + + case r == tableSep: + return lx.errorf("Unexpected table separator. 
(Table names cannot " + "be empty.)") + case r == stringStart || r == rawStringStart: + lx.ignore() + lx.push(lexTableNameEnd) + return lexValue // reuse string lexing + default: + return lexBareTableName } - return lexTableName } // lexTableName lexes the name of a table. It assumes that at least one // valid character for the table has already been read. -func lexTableName(lx *lexer) stateFn { - switch lx.peek() { - case eof: - return lx.errorf("Unexpected end of table name %q.", lx.current()) - case tableStart: - return lx.errorf("Table names cannot contain %q or %q.", - tableStart, tableEnd) - case tableEnd: - lx.emit(itemText) - lx.next() - return lx.pop() - case tableSep: - lx.emit(itemText) - lx.next() +func lexBareTableName(lx *lexer) stateFn { + switch r := lx.next(); { + case isBareKeyChar(r): + return lexBareTableName + case r == tableSep || r == tableEnd: + lx.backup() + lx.emitTrim(itemText) + return lexTableNameEnd + default: + return lx.errorf("Bare keys cannot contain %q.", r) + } +} + +// lexTableNameEnd reads the end of a piece of a table name, optionally +// consuming whitespace. +func lexTableNameEnd(lx *lexer) stateFn { + switch r := lx.next(); { + case isWhitespace(r): + return lexTableNameEnd + case r == tableSep: lx.ignore() return lexTableNameStart + case r == tableEnd: + return lx.pop() + default: + return lx.errorf("Expected '.' or ']' to end table name, but got %q "+ + "instead.", r) } - lx.next() - return lexTableName } // lexKeyStart consumes a key name up until the first non-whitespace character. 
@@ -300,53 +319,48 @@ func lexKeyStart(lx *lexer) stateFn { case isWhitespace(r) || isNL(r): lx.next() return lexSkip(lx, lexKeyStart) + case r == stringStart || r == rawStringStart: + lx.ignore() + lx.emit(itemKeyStart) + lx.push(lexKeyEnd) + return lexValue // reuse string lexing + default: + lx.ignore() + lx.emit(itemKeyStart) + return lexBareKey } - - lx.ignore() - lx.emit(itemKeyStart) - lx.next() - return lexKey } -// lexKey consumes the text of a key. Assumes that the first character (which -// is not whitespace) has already been consumed. -func lexKey(lx *lexer) stateFn { - r := lx.peek() - - // Keys cannot contain a '#' character. - if r == commentStart { - return lx.errorf("Key cannot contain a '#' character.") - } - - // XXX: Possible divergence from spec? - // "Keys start with the first non-whitespace character and end with the - // last non-whitespace character before the equals sign." - // Note here that whitespace is either a tab or a space. - // But we'll call it quits if we see a new line too. - if isNL(r) { +// lexBareKey consumes the text of a bare key. Assumes that the first character +// (which is not whitespace) has not yet been consumed. +func lexBareKey(lx *lexer) stateFn { + switch r := lx.next(); { + case isBareKeyChar(r): + return lexBareKey + case isWhitespace(r): lx.emitTrim(itemText) return lexKeyEnd - } - - // Let's also call it quits if we see an equals sign. - if r == keySep { + case r == keySep: + lx.backup() lx.emitTrim(itemText) return lexKeyEnd + default: + return lx.errorf("Bare keys cannot contain %q.", r) } - - lx.next() - return lexKey } -// lexKeyEnd consumes the end of a key (up to the key separator). -// Assumes that any whitespace after a key has been consumed. +// lexKeyEnd consumes the end of a key and trims whitespace (up to the key +// separator). 
func lexKeyEnd(lx *lexer) stateFn { - r := lx.next() - if r == keySep { + switch r := lx.next(); { + case r == keySep: return lexSkip(lx, lexValue) + case isWhitespace(r): + return lexSkip(lx, lexKeyEnd) + default: + return lx.errorf("Expected key separator %q, but got %q instead.", + keySep, r) } - return lx.errorf("Expected key separator %q, but got %q instead.", - keySep, r) } // lexValue starts the consumption of a value anywhere a value is expected. @@ -354,7 +368,8 @@ func lexKeyEnd(lx *lexer) stateFn { // After a value is lexed, the last state on the next is popped and returned. func lexValue(lx *lexer) stateFn { // We allow whitespace to precede a value, but NOT new lines. - // In array syntax, the array states are responsible for ignoring new lines. + // In array syntax, the array states are responsible for ignoring new + // lines. r := lx.next() if isWhitespace(r) { return lexSkip(lx, lexValue) @@ -366,8 +381,25 @@ func lexValue(lx *lexer) stateFn { lx.emit(itemArray) return lexArrayValue case r == stringStart: + if lx.accept(stringStart) { + if lx.accept(stringStart) { + lx.ignore() // Ignore """ + return lexMultilineString + } + lx.backup() + } lx.ignore() // ignore the '"' return lexString + case r == rawStringStart: + if lx.accept(rawStringStart) { + if lx.accept(rawStringStart) { + lx.ignore() // Ignore """ + return lexMultilineRawString + } + lx.backup() + } + lx.ignore() // ignore the "'" + return lexRawString case r == 't': return lexTrue case r == 'f': @@ -441,6 +473,7 @@ func lexString(lx *lexer) stateFn { case isNL(r): return lx.errorf("Strings cannot contain new lines.") case r == '\\': + lx.push(lexString) return lexStringEscape case r == stringEnd: lx.backup() @@ -452,8 +485,87 @@ func lexString(lx *lexer) stateFn { return lexString } -// lexStringEscape consumes an escaped character. It assumes that the preceding -// '\\' has already been consumed. +// lexMultilineString consumes the inner contents of a string. 
It assumes that +// the beginning '"""' has already been consumed and ignored. +func lexMultilineString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == '\\': + return lexMultilineStringEscape + case r == stringEnd: + if lx.accept(stringEnd) { + if lx.accept(stringEnd) { + lx.backup() + lx.backup() + lx.backup() + lx.emit(itemMultilineString) + lx.next() + lx.next() + lx.next() + lx.ignore() + return lx.pop() + } + lx.backup() + } + } + return lexMultilineString +} + +// lexRawString consumes a raw string. Nothing can be escaped in such a string. +// It assumes that the beginning "'" has already been consumed and ignored. +func lexRawString(lx *lexer) stateFn { + r := lx.next() + switch { + case isNL(r): + return lx.errorf("Strings cannot contain new lines.") + case r == rawStringEnd: + lx.backup() + lx.emit(itemRawString) + lx.next() + lx.ignore() + return lx.pop() + } + return lexRawString +} + +// lexMultilineRawString consumes a raw string. Nothing can be escaped in such +// a string. It assumes that the beginning "'" has already been consumed and +// ignored. +func lexMultilineRawString(lx *lexer) stateFn { + r := lx.next() + switch { + case r == rawStringEnd: + if lx.accept(rawStringEnd) { + if lx.accept(rawStringEnd) { + lx.backup() + lx.backup() + lx.backup() + lx.emit(itemRawMultilineString) + lx.next() + lx.next() + lx.next() + lx.ignore() + return lx.pop() + } + lx.backup() + } + } + return lexMultilineRawString +} + +// lexMultilineStringEscape consumes an escaped character. It assumes that the +// preceding '\\' has already been consumed. 
+func lexMultilineStringEscape(lx *lexer) stateFn { + // Handle the special case first: + if isNL(lx.next()) { + return lexMultilineString + } else { + lx.backup() + lx.push(lexMultilineString) + return lexStringEscape(lx) + } +} + func lexStringEscape(lx *lexer) stateFn { r := lx.next() switch r { @@ -469,35 +581,45 @@ func lexStringEscape(lx *lexer) stateFn { fallthrough case '"': fallthrough - case '/': - fallthrough case '\\': - return lexString + return lx.pop() case 'u': - return lexStringUnicode + return lexShortUnicodeEscape + case 'U': + return lexLongUnicodeEscape } return lx.errorf("Invalid escape character %q. Only the following "+ "escape characters are allowed: "+ - "\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, and \\uXXXX.", r) + "\\b, \\t, \\n, \\f, \\r, \\\", \\/, \\\\, "+ + "\\uXXXX and \\UXXXXXXXX.", r) } -// lexStringBinary consumes two hexadecimal digits following '\x'. It assumes -// that the '\x' has already been consumed. -func lexStringUnicode(lx *lexer) stateFn { +func lexShortUnicodeEscape(lx *lexer) stateFn { var r rune - for i := 0; i < 4; i++ { r = lx.next() if !isHexadecimal(r) { - return lx.errorf("Expected four hexadecimal digits after '\\x', "+ + return lx.errorf("Expected four hexadecimal digits after '\\u', "+ "but got '%s' instead.", lx.current()) } } - return lexString + return lx.pop() +} + +func lexLongUnicodeEscape(lx *lexer) stateFn { + var r rune + for i := 0; i < 8; i++ { + r = lx.next() + if !isHexadecimal(r) { + return lx.errorf("Expected eight hexadecimal digits after '\\U', "+ + "but got '%s' instead.", lx.current()) + } + } + return lx.pop() } -// lexNumberOrDateStart consumes either a (positive) integer, float or datetime. -// It assumes that NO negative sign has been consumed. +// lexNumberOrDateStart consumes either a (positive) integer, float or +// datetime. It assumes that NO negative sign has been consumed. 
func lexNumberOrDateStart(lx *lexer) stateFn { r := lx.next() if !isDigit(r) { @@ -557,9 +679,10 @@ func lexDateAfterYear(lx *lexer) stateFn { return lx.pop() } -// lexNumberStart consumes either an integer or a float. It assumes that a -// negative sign has already been read, but that *no* digits have been consumed. -// lexNumberStart will move to the appropriate integer or float states. +// lexNumberStart consumes either an integer or a float. It assumes that +// a negative sign has already been read, but that *no* digits have been +// consumed. lexNumberStart will move to the appropriate integer or float +// states. func lexNumberStart(lx *lexer) stateFn { // we MUST see a digit. Even floats have to start with a digit. r := lx.next() @@ -693,6 +816,14 @@ func isHexadecimal(r rune) bool { (r >= 'A' && r <= 'F') } +func isBareKeyChar(r rune) bool { + return (r >= 'A' && r <= 'Z') || + (r >= 'a' && r <= 'z') || + (r >= '0' && r <= '9') || + r == '_' || + r == '-' +} + func (itype itemType) String() string { switch itype { case itemError: @@ -705,6 +836,12 @@ func (itype itemType) String() string { return "Text" case itemString: return "String" + case itemRawString: + return "String" + case itemMultilineString: + return "String" + case itemRawMultilineString: + return "String" case itemBool: return "Bool" case itemInteger: diff --git a/parse.go b/parse.go index 43afe3c3f..6a82e84f6 100644 --- a/parse.go +++ b/parse.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" "time" + "unicode" "unicode/utf8" ) @@ -66,7 +67,7 @@ func parse(data string) (p *parser, err error) { } func (p *parser) panicf(format string, v ...interface{}) { - msg := fmt.Sprintf("Near line %d, key '%s': %s", + msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s", p.approxLine, p.current(), fmt.Sprintf(format, v...)) panic(parseError(msg)) } @@ -74,13 +75,13 @@ func (p *parser) panicf(format string, v ...interface{}) { func (p *parser) next() item { it := p.lx.nextItem() if it.typ == itemError 
{ - p.panicf("Near line %d: %s", it.line, it.val) + p.panicf("%s", it.val) } return it } func (p *parser) bug(format string, v ...interface{}) { - log.Fatalf("BUG: %s\n\n", fmt.Sprintf(format, v...)) + log.Panicf("BUG: %s\n\n", fmt.Sprintf(format, v...)) } func (p *parser) expect(typ itemType) item { @@ -101,12 +102,12 @@ func (p *parser) topLevel(item item) { p.approxLine = item.line p.expect(itemText) case itemTableStart: - kg := p.expect(itemText) + kg := p.next() p.approxLine = kg.line - key := make(Key, 0) - for ; kg.typ == itemText; kg = p.next() { - key = append(key, kg.val) + var key Key + for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() { + key = append(key, p.keyString(kg)) } p.assertEqual(itemTableEnd, kg.typ) @@ -114,12 +115,12 @@ func (p *parser) topLevel(item item) { p.setType("", tomlHash) p.ordered = append(p.ordered, key) case itemArrayTableStart: - kg := p.expect(itemText) + kg := p.next() p.approxLine = kg.line - key := make(Key, 0) - for ; kg.typ == itemText; kg = p.next() { - key = append(key, kg.val) + var key Key + for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() { + key = append(key, p.keyString(kg)) } p.assertEqual(itemArrayTableEnd, kg.typ) @@ -127,27 +128,48 @@ func (p *parser) topLevel(item item) { p.setType("", tomlArrayHash) p.ordered = append(p.ordered, key) case itemKeyStart: - kname := p.expect(itemText) - p.currentKey = kname.val + kname := p.next() p.approxLine = kname.line + p.currentKey = p.keyString(kname) val, typ := p.value(p.next()) p.setValue(p.currentKey, val) p.setType(p.currentKey, typ) p.ordered = append(p.ordered, p.context.add(p.currentKey)) - p.currentKey = "" default: p.bug("Unexpected type at top level: %s", item.typ) } } +// Gets a string for a key (or part of a key in a table name). 
+func (p *parser) keyString(it item) string { + switch it.typ { + case itemText: + return it.val + case itemString, itemMultilineString, + itemRawString, itemRawMultilineString: + s, _ := p.value(it) + return s.(string) + default: + p.bug("Unexpected key type: %s", it.typ) + panic("unreachable") + } +} + // value translates an expected value from the lexer into a Go value wrapped // as an empty interface. func (p *parser) value(it item) (interface{}, tomlType) { switch it.typ { case itemString: - return p.replaceUnicode(replaceEscapes(it.val)), p.typeOfPrimitive(it) + return p.replaceEscapes(it.val), p.typeOfPrimitive(it) + case itemMultilineString: + trimmed := stripFirstNewline(stripEscapedWhitespace(it.val)) + return p.replaceEscapes(trimmed), p.typeOfPrimitive(it) + case itemRawString: + return it.val, p.typeOfPrimitive(it) + case itemRawMultilineString: + return stripFirstNewline(it.val), p.typeOfPrimitive(it) case itemBool: switch it.val { case "true": @@ -194,7 +216,7 @@ func (p *parser) value(it item) (interface{}, tomlType) { case itemDatetime: t, err := time.Parse("2006-01-02T15:04:05Z", it.val) if err != nil { - p.bug("Expected Zulu formatted DateTime, but got '%s'.", it.val) + p.panicf("Invalid RFC3339 Zulu DateTime: '%s'.", it.val) } return t, p.typeOfPrimitive(it) case itemArray: @@ -352,7 +374,8 @@ func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = true } -// removeImplicit stops tagging the given key as having been implicitly created. +// removeImplicit stops tagging the given key as having been implicitly +// created. 
func (p *parser) removeImplicit(key Key) { p.implicits[key.String()] = false } @@ -374,44 +397,97 @@ func (p *parser) current() string { return fmt.Sprintf("%s.%s", p.context, p.currentKey) } -func replaceEscapes(s string) string { - return strings.NewReplacer( - "\\b", "\u0008", - "\\t", "\u0009", - "\\n", "\u000A", - "\\f", "\u000C", - "\\r", "\u000D", - "\\\"", "\u0022", - "\\/", "\u002F", - "\\\\", "\u005C", - ).Replace(s) +func stripFirstNewline(s string) string { + if len(s) == 0 || s[0] != '\n' { + return s + } + return s[1:] } -func (p *parser) replaceUnicode(s string) string { - indexEsc := func() int { - return strings.Index(s, "\\u") +func stripEscapedWhitespace(s string) string { + esc := strings.Split(s, "\\\n") + if len(esc) > 1 { + for i := 1; i < len(esc); i++ { + esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace) + } } - for i := indexEsc(); i != -1; i = indexEsc() { - asciiBytes := s[i+2 : i+6] - s = strings.Replace(s, s[i:i+6], p.asciiEscapeToUnicode(asciiBytes), -1) + return strings.Join(esc, "") +} + +func (p *parser) replaceEscapes(str string) string { + var replaced []rune + s := []byte(str) + r := 0 + for r < len(s) { + if s[r] != '\\' { + c, size := utf8.DecodeRune(s[r:]) + r += size + replaced = append(replaced, c) + continue + } + r += 1 + if r >= len(s) { + p.bug("Escape sequence at end of string.") + return "" + } + switch s[r] { + default: + p.bug("Expected valid escape code after \\, but got %q.", s[r]) + return "" + case 'b': + replaced = append(replaced, rune(0x0008)) + r += 1 + case 't': + replaced = append(replaced, rune(0x0009)) + r += 1 + case 'n': + replaced = append(replaced, rune(0x000A)) + r += 1 + case 'f': + replaced = append(replaced, rune(0x000C)) + r += 1 + case 'r': + replaced = append(replaced, rune(0x000D)) + r += 1 + case '"': + replaced = append(replaced, rune(0x0022)) + r += 1 + case '\\': + replaced = append(replaced, rune(0x005C)) + r += 1 + case 'u': + // At this point, we know we have a Unicode escape of 
the form + // `uXXXX` at [r, r+5). (Because the lexer guarantees this + // for us.) + escaped := p.asciiEscapeToUnicode(s[r+1 : r+5]) + replaced = append(replaced, escaped) + r += 5 + case 'U': + // At this point, we know we have a Unicode escape of the form + // `uXXXX` at [r, r+9). (Because the lexer guarantees this + // for us.) + escaped := p.asciiEscapeToUnicode(s[r+1 : r+9]) + replaced = append(replaced, escaped) + r += 9 + } } - return s + return string(replaced) } -func (p *parser) asciiEscapeToUnicode(s string) string { +func (p *parser) asciiEscapeToUnicode(bs []byte) rune { + s := string(bs) hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32) if err != nil { p.bug("Could not parse '%s' as a hexadecimal number, but the "+ "lexer claims it's OK: %s", s, err) } - - // BUG(burntsushi) - // I honestly don't understand how this works. I can't seem - // to find a way to make this fail. I figured this would fail on invalid - // UTF-8 characters like U+DCFF, but it doesn't. - r := string(rune(hex)) - if !utf8.ValidString(r) { + if !utf8.ValidRune(rune(hex)) { p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s) } - return string(r) + return rune(hex) +} + +func isStringType(ty itemType) bool { + return ty == itemString || ty == itemMultilineString || + ty == itemRawString || ty == itemRawMultilineString } diff --git a/type_check.go b/type_check.go index 79dac6b19..c73f8afc1 100644 --- a/type_check.go +++ b/type_check.go @@ -56,6 +56,12 @@ func (p *parser) typeOfPrimitive(lexItem item) tomlType { return tomlDatetime case itemString: return tomlString + case itemMultilineString: + return tomlString + case itemRawString: + return tomlString + case itemRawMultilineString: + return tomlString case itemBool: return tomlBool } @@ -77,8 +83,8 @@ func (p *parser) typeOfArray(types []tomlType) tomlType { theType := types[0] for _, t := range types[1:] { if !typeEqual(theType, t) { - p.panicf("Array contains values of type '%s' and '%s', but arrays "+ - "must 
be homogeneous.", theType, t) + p.panicf("Array contains values of type '%s' and '%s', but "+ + "arrays must be homogeneous.", theType, t) } } return tomlArray diff --git a/type_fields.go b/type_fields.go index 7592f87a4..6da608af4 100644 --- a/type_fields.go +++ b/type_fields.go @@ -92,10 +92,10 @@ func typeFields(t reflect.Type) []field { // Scan f.typ for fields to include. for i := 0; i < f.typ.NumField(); i++ { sf := f.typ.Field(i) - if sf.PkgPath != "" { // unexported + if sf.PkgPath != "" && !sf.Anonymous { // unexported continue } - name := sf.Tag.Get("toml") + name, _ := getOptions(sf.Tag.Get("toml")) if name == "-" { continue }