Skip to content

Commit

Permalink
Eagerly raise an error if parse_float produces illegal types (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
hukkin committed Feb 3, 2022
1 parent 0eaf93d commit 794c8e5
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ Note that `decimal.Decimal` can be replaced with another callable that converts
The `decimal.Decimal` is, however, a practical choice for use cases where float inaccuracies can not be tolerated.

Illegal types are `dict` and `list`, and their subtypes.
Parsing floats into an illegal type results in undefined behavior.
A `ValueError` will be raised if `parse_float` produces illegal types.

## FAQ<a name="faq"></a>

Expand Down
22 changes: 22 additions & 0 deletions src/tomli/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def loads(__s: str, *, parse_float: ParseFloat = float) -> dict[str, Any]: # no
pos = 0
out = Output(NestedDict(), Flags())
header: Key = ()
parse_float = make_safe_parse_float(parse_float)

# Parse one statement at a time
# (typically means one line in TOML source)
Expand Down Expand Up @@ -667,3 +668,24 @@ def coord_repr(src: str, pos: Pos) -> str:

def is_unicode_scalar_value(codepoint: int) -> bool:
return (0 <= codepoint <= 55295) or (57344 <= codepoint <= 1114111)


def make_safe_parse_float(parse_float: ParseFloat) -> ParseFloat:
"""A decorator to make `parse_float` safe.
`parse_float` must not return dicts or lists, because these types
would be mixed with parsed TOML tables and arrays, thus confusing
the parser. The returned decorated callable raises `ValueError`
instead of returning illegal types.
"""
# The default `float` callable never returns illegal types. Optimize it.
if parse_float is float: # type: ignore[comparison-overlap]
return float

def safe_parse_float(float_str: str) -> Any:
float_value = parse_float(float_str)
if isinstance(float_value, (dict, list)):
raise ValueError("parse_float must not return dicts or lists")
return float_value

return safe_parse_float
14 changes: 14 additions & 0 deletions tests/test_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,17 @@ def test_invalid_char_quotes(self):

def test_module_name(self):
self.assertEqual(tomllib.TOMLDecodeError().__module__, tomllib.__name__)

def test_invalid_parse_float(self):
def dict_returner(s: str) -> dict:
return {}

def list_returner(s: str) -> list:
return []

for invalid_parse_float in (dict_returner, list_returner):
with self.assertRaises(ValueError) as exc_info:
tomllib.loads("f=0.1", parse_float=invalid_parse_float)
self.assertEqual(
str(exc_info.exception), "parse_float must not return dicts or lists"
)

0 comments on commit 794c8e5

Please sign in to comment.