Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use _ffi.from_buffer() to support bytearray #852

Merged
merged 14 commits into from
Nov 18, 2019
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ doc/_build/
examples/simple/*.cert
examples/simple/*.pkey
.cache
.mypy_cache
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ Deprecations:
Changes:
^^^^^^^^

*none*
- Support ``bytearray`` in ``SSL.Connection.send()`` by using cffi's from_buffer.
`#852 <https://github.com/pyca/pyopenssl/pull/852>`_


----
Expand Down
72 changes: 35 additions & 37 deletions src/OpenSSL/SSL.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
UNSPECIFIED as _UNSPECIFIED,
exception_from_error_queue as _exception_from_error_queue,
ffi as _ffi,
from_buffer as _from_buffer,
lib as _lib,
make_assert as _make_assert,
native as _native,
Expand Down Expand Up @@ -1725,18 +1726,18 @@ def send(self, buf, flags=0):
# Backward compatibility
buf = _text_to_bytes_and_warn("buf", buf)

if isinstance(buf, memoryview):
buf = buf.tobytes()
if isinstance(buf, _buffer):
buf = str(buf)
if not isinstance(buf, bytes):
raise TypeError("data must be a memoryview, buffer or byte string")
if len(buf) > 2147483647:
raise ValueError("Cannot send more than 2**31-1 bytes at once.")
with _from_buffer(buf) as data:
# check len(buf) instead of len(data) for testability
if len(buf) > 2147483647:
raise ValueError(
"Cannot send more than 2**31-1 bytes at once."
)

result = _lib.SSL_write(self._ssl, data, len(data))
self._raise_ssl_error(self._ssl, result)

return result

result = _lib.SSL_write(self._ssl, buf, len(buf))
self._raise_ssl_error(self._ssl, result)
return result
write = send

def sendall(self, buf, flags=0):
Expand All @@ -1752,28 +1753,24 @@ def sendall(self, buf, flags=0):
"""
buf = _text_to_bytes_and_warn("buf", buf)

if isinstance(buf, memoryview):
buf = buf.tobytes()
if isinstance(buf, _buffer):
buf = str(buf)
if not isinstance(buf, bytes):
raise TypeError("buf must be a memoryview, buffer or byte string")

left_to_send = len(buf)
total_sent = 0
data = _ffi.new("char[]", buf)

while left_to_send:
# SSL_write's num arg is an int,
# so we cannot send more than 2**31-1 bytes at once.
result = _lib.SSL_write(
self._ssl,
data + total_sent,
min(left_to_send, 2147483647)
)
self._raise_ssl_error(self._ssl, result)
total_sent += result
left_to_send -= result
with _from_buffer(buf) as data:

left_to_send = len(buf)
total_sent = 0

while left_to_send:
# SSL_write's num arg is an int,
# so we cannot send more than 2**31-1 bytes at once.
result = _lib.SSL_write(
self._ssl,
data + total_sent,
min(left_to_send, 2147483647)
)
self._raise_ssl_error(self._ssl, result)
total_sent += result
left_to_send -= result

return total_sent

def recv(self, bufsiz, flags=None):
"""
Expand Down Expand Up @@ -1887,10 +1884,11 @@ def bio_write(self, buf):
if self._into_ssl is None:
raise TypeError("Connection sock was not None")

result = _lib.BIO_write(self._into_ssl, buf, len(buf))
if result <= 0:
self._handle_bio_errors(self._into_ssl, result)
return result
with _from_buffer(buf) as data:
result = _lib.BIO_write(self._into_ssl, data, len(data))
if result <= 0:
self._handle_bio_errors(self._into_ssl, result)
return result

def renegotiate(self):
"""
Expand Down
14 changes: 14 additions & 0 deletions src/OpenSSL/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,17 @@ def text_to_bytes_and_warn(label, obj):
)
return obj.encode('utf-8')
return obj


try:
# newer versions of cffi free the buffer deterministically
with ffi.from_buffer(b""):
pass
from_buffer = ffi.from_buffer
except AttributeError:
# cffi < 0.12 frees the buffer with refcounting gc
from contextlib import contextmanager

@contextmanager
def from_buffer(*args):
yield ffi.from_buffer(*args)
42 changes: 40 additions & 2 deletions tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,29 @@ def test_wrong_args(self, bad_context):
with pytest.raises(TypeError):
Connection(bad_context)

@pytest.mark.parametrize('bad_bio', [object(), None, 1, [1, 2, 3]])
def test_bio_write_wrong_args(self, bad_bio):
"""
`Connection.bio_write` raises `TypeError` if called with a non-bytes
(or text) argument.
"""
context = Context(TLSv1_METHOD)
connection = Connection(context, None)
with pytest.raises(TypeError):
connection.bio_write(bad_bio)

def test_bio_write(self):
"""
`Connection.bio_write` does not raise if called with bytes or
bytearray, warns if called with text.
"""
context = Context(TLSv1_METHOD)
connection = Connection(context, None)
connection.bio_write(b'xy')
connection.bio_write(bytearray(b'za'))
with pytest.warns(DeprecationWarning):
connection.bio_write(u'deprecated')

def test_get_context(self):
"""
`Connection.get_context` returns the `Context` instance used to
Expand Down Expand Up @@ -2804,6 +2827,8 @@ def test_wrong_args(self):
connection = Connection(Context(TLSv1_METHOD), None)
with pytest.raises(TypeError):
connection.send(object())
with pytest.raises(TypeError):
connection.send([1, 2, 3])

def test_short_bytes(self):
"""
Expand Down Expand Up @@ -2842,6 +2867,16 @@ def test_short_memoryview(self):
assert count == 2
assert client.recv(2) == b'xy'

def test_short_bytearray(self):
"""
When passed a short bytearray, `Connection.send` transmits all of
it and returns the number of bytes sent.
"""
server, client = loopback()
count = server.send(bytearray(b'xy'))
assert count == 2
assert client.recv(2) == b'xy'

@skip_if_py3
def test_short_buffer(self):
"""
Expand Down Expand Up @@ -3012,6 +3047,8 @@ def test_wrong_args(self):
connection = Connection(Context(TLSv1_METHOD), None)
with pytest.raises(TypeError):
connection.sendall(object())
with pytest.raises(TypeError):
connection.sendall([1, 2, 3])

def test_short(self):
"""
Expand Down Expand Up @@ -3053,8 +3090,9 @@ def test_short_buffers(self):
`Connection.sendall` transmits all of them.
"""
server, client = loopback()
server.sendall(buffer(b'x'))
assert client.recv(1) == b'x'
count = server.sendall(buffer(b'xy'))
assert count == 2
assert client.recv(2) == b'xy'

def test_long(self):
"""
Expand Down