From e36f34597ca96c40dbb81d63a8fe9ac0860b8dc3 Mon Sep 17 00:00:00 2001 From: Nate Prewitt Date: Wed, 8 Jun 2022 12:03:56 -0600 Subject: [PATCH] Add validation for header name (#6154) --- requests/_internal_utils.py | 11 +++++ requests/utils.py | 35 +++++++-------- tests/test_requests.py | 88 +++++++++++++++++++------------------ 3 files changed, 72 insertions(+), 62 deletions(-) diff --git a/requests/_internal_utils.py b/requests/_internal_utils.py index a82e36c8cb..7dc9bc5336 100644 --- a/requests/_internal_utils.py +++ b/requests/_internal_utils.py @@ -5,9 +5,20 @@ Provides utility functions that are consumed internally by Requests which depend on extremely few external helpers (such as compat) """ +import re from .compat import builtin_str +_VALID_HEADER_NAME_RE_BYTE = re.compile(rb"^[^:\s][^:\r\n]*$") +_VALID_HEADER_NAME_RE_STR = re.compile(r"^[^:\s][^:\r\n]*$") +_VALID_HEADER_VALUE_RE_BYTE = re.compile(rb"^\S[^\r\n]*$|^$") +_VALID_HEADER_VALUE_RE_STR = re.compile(r"^\S[^\r\n]*$|^$") + +HEADER_VALIDATORS = { + bytes: (_VALID_HEADER_NAME_RE_BYTE, _VALID_HEADER_VALUE_RE_BYTE), + str: (_VALID_HEADER_NAME_RE_STR, _VALID_HEADER_VALUE_RE_STR), +} + def to_native_string(string, encoding="ascii"): """Given a string object, regardless of type, returns a representation of diff --git a/requests/utils.py b/requests/utils.py index 1baab8ccad..ad5358381a 100644 --- a/requests/utils.py +++ b/requests/utils.py @@ -25,7 +25,7 @@ from .__version__ import __version__ # to_native_string is unused here, but imported here for backwards compatibility -from ._internal_utils import to_native_string # noqa: F401 +from ._internal_utils import HEADER_VALIDATORS, to_native_string # noqa: F401 from .compat import ( Mapping, basestring, @@ -1024,33 +1024,30 @@ def get_auth_from_url(url): return auth -# Moved outside of function to avoid recompile every call -_CLEAN_HEADER_REGEX_BYTE = re.compile(b"^\\S[^\\r\\n]*$|^$") -_CLEAN_HEADER_REGEX_STR = re.compile(r"^\S[^\r\n]*$|^$") - - def 
check_header_validity(header): - """Verifies that header value is a string which doesn't contain - leading whitespace or return characters. This prevents unintended - header injection. + """Verifies that header parts don't contain leading whitespace, + reserved characters, or return characters. :param header: tuple, in the format (name, value). """ name, value = header - if isinstance(value, bytes): - pat = _CLEAN_HEADER_REGEX_BYTE - else: - pat = _CLEAN_HEADER_REGEX_STR - try: - if not pat.match(value): + for part in header: + if type(part) not in HEADER_VALIDATORS: raise InvalidHeader( - f"Invalid return character or leading space in header: {name}" + f"Header part ({part!r}) from {{{name!r}: {value!r}}} must be " + f"of type str or bytes, not {type(part)}" ) - except TypeError: + + _validate_header_part(name, "name", HEADER_VALIDATORS[type(name)][0]) + _validate_header_part(value, "value", HEADER_VALIDATORS[type(value)][1]) + + +def _validate_header_part(header_part, header_kind, validator): + if not validator.match(header_part): raise InvalidHeader( - f"Value for header {{{name}: {value}}} must be of type " - f"str or bytes, not {type(value)}" + f"Invalid leading whitespace, reserved character(s), or return" + f"character(s) in header {header_kind}: {header_part!r}" ) diff --git a/tests/test_requests.py b/tests/test_requests.py index b724264166..5b4c3f53fa 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1096,7 +1096,7 @@ def test_non_prepared_request_error(self): def test_custom_content_type(self, httpbin): with open(__file__, "rb") as f1: with open(__file__, "rb") as f2: - data={"stuff": json.dumps({"a": 123})} + data = {"stuff": json.dumps({"a": 123})} files = { "file1": ("test_requests.py", f1), "file2": ("test_requests", f2, "text/py-content-type"), @@ -1682,68 +1682,70 @@ def test_header_keys_are_native(self, httpbin): def test_header_validation(self, httpbin): """Ensure prepare_headers regex isn't flagging valid header contents.""" - 
headers_ok = { + valid_headers = { "foo": "bar baz qux", "bar": b"fbbq", "baz": "", "qux": "1", } - r = requests.get(httpbin("get"), headers=headers_ok) - assert r.request.headers["foo"] == headers_ok["foo"] + r = requests.get(httpbin("get"), headers=valid_headers) + for key in valid_headers.keys(): + valid_headers[key] == r.request.headers[key] - def test_header_value_not_str(self, httpbin): + @pytest.mark.parametrize( + "invalid_header, key", + ( + ({"foo": 3}, "foo"), + ({"bar": {"foo": "bar"}}, "bar"), + ({"baz": ["foo", "bar"]}, "baz"), + ), + ) + def test_header_value_not_str(self, httpbin, invalid_header, key): """Ensure the header value is of type string or bytes as per discussion in GH issue #3386 """ - headers_int = {"foo": 3} - headers_dict = {"bar": {"foo": "bar"}} - headers_list = {"baz": ["foo", "bar"]} - - # Test for int - with pytest.raises(InvalidHeader) as excinfo: - requests.get(httpbin("get"), headers=headers_int) - assert "foo" in str(excinfo.value) - # Test for dict with pytest.raises(InvalidHeader) as excinfo: - requests.get(httpbin("get"), headers=headers_dict) - assert "bar" in str(excinfo.value) - # Test for list - with pytest.raises(InvalidHeader) as excinfo: - requests.get(httpbin("get"), headers=headers_list) - assert "baz" in str(excinfo.value) + requests.get(httpbin("get"), headers=invalid_header) + assert key in str(excinfo.value) - def test_header_no_return_chars(self, httpbin): + @pytest.mark.parametrize( + "invalid_header", + ( + {"foo": "bar\r\nbaz: qux"}, + {"foo": "bar\n\rbaz: qux"}, + {"foo": "bar\nbaz: qux"}, + {"foo": "bar\rbaz: qux"}, + {"fo\ro": "bar"}, + {"fo\r\no": "bar"}, + {"fo\n\ro": "bar"}, + {"fo\no": "bar"}, + ), + ) + def test_header_no_return_chars(self, httpbin, invalid_header): """Ensure that a header containing return character sequences raise an exception. Otherwise, multiple headers are created from single string. 
""" - headers_ret = {"foo": "bar\r\nbaz: qux"} - headers_lf = {"foo": "bar\nbaz: qux"} - headers_cr = {"foo": "bar\rbaz: qux"} - - # Test for newline - with pytest.raises(InvalidHeader): - requests.get(httpbin("get"), headers=headers_ret) - # Test for line feed - with pytest.raises(InvalidHeader): - requests.get(httpbin("get"), headers=headers_lf) - # Test for carriage return with pytest.raises(InvalidHeader): - requests.get(httpbin("get"), headers=headers_cr) + requests.get(httpbin("get"), headers=invalid_header) - def test_header_no_leading_space(self, httpbin): + @pytest.mark.parametrize( + "invalid_header", + ( + {" foo": "bar"}, + {"\tfoo": "bar"}, + {" foo": "bar"}, + {"foo": " bar"}, + {"foo": " bar"}, + {"foo": "\tbar"}, + {" ": "bar"}, + ), + ) + def test_header_no_leading_space(self, httpbin, invalid_header): """Ensure headers containing leading whitespace raise InvalidHeader Error before sending. """ - headers_space = {"foo": " bar"} - headers_tab = {"foo": " bar"} - - # Test for whitespace - with pytest.raises(InvalidHeader): - requests.get(httpbin("get"), headers=headers_space) - - # Test for tab with pytest.raises(InvalidHeader): - requests.get(httpbin("get"), headers=headers_tab) + requests.get(httpbin("get"), headers=invalid_header) @pytest.mark.parametrize("files", ("foo", b"foo", bytearray(b"foo"))) def test_can_send_objects_with_files(self, httpbin, files):