Skip to content

Commit

Permalink
Corrects BitString decoding
Browse files Browse the repository at this point in the history
Properly handles initial octet which stores the number of unused bits
and removes the corresponding number of bits.

Adds corresponding unit tests.
  • Loading branch information
0xbf00 committed Feb 28, 2022
1 parent a7d95a9 commit 18b3b7d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,8 @@ def _read_value(self, cls, nr, length): # type: (int, int, int) -> any
value = self._decode_object_identifier(bytes_data)
elif nr in (Numbers.PrintableString, Numbers.IA5String, Numbers.UTCTime):
value = self._decode_printable_string(bytes_data)
elif nr == Numbers.BitString:
value = self._decode_bitstring(bytes_data)
else:
value = bytes_data
return value
Expand Down Expand Up @@ -647,3 +649,28 @@ def _decode_object_identifier(bytes_data): # type: (bytes) -> str
def _decode_printable_string(bytes_data): # type: (bytes) -> str
"""Decode a printable string."""
return bytes_data.decode('utf-8')

@staticmethod
def _decode_bitstring(bytes_data): # type: (bytes) -> str
"""Decode a bitstring."""
if len(bytes_data) == 0:
raise Error('ASN1 syntax error')

num_unused_bits = bytes_data[0]
if not (0 <= num_unused_bits <= 7):
raise Error('ASN1 syntax error')

if num_unused_bits == 0:
return bytes_data[1:]

# Shift off unused bits
remaining = bytearray(bytes_data[1:])
bitmask = (1 << num_unused_bits) - 1
removed_bits = 0

for i in range(len(remaining)):
byte = int(remaining[i])
remaining[i] = (byte >> num_unused_bits) | (removed_bits << num_unused_bits)
removed_bits = byte & bitmask

return bytes(remaining)
24 changes: 24 additions & 0 deletions tests/test_asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,24 @@ def test_printable_string(self):
tag, val = dec.read()
assert val == u'foo'

def test_bitstring(self):
buf = b'\x03\x04\x00\x12\x34\x56'
dec = asn1.Decoder()
dec.start(buf)
tag = dec.peek()
assert tag == (asn1.Numbers.BitString, asn1.Types.Primitive, asn1.Classes.Universal)
tag, val = dec.read()
assert val == b'\x12\x34\x56'

def test_bitstring_unused_bits(self):
buf = b'\x03\x04\x04\x12\x34\x50'
dec = asn1.Decoder()
dec.start(buf)
tag = dec.peek()
assert tag == (asn1.Numbers.BitString, asn1.Types.Primitive, asn1.Classes.Universal)
tag, val = dec.read()
assert val == b'\x01\x23\x45'

def test_unicode_printable_string(self):
buf = b'\x13\x05\x66\x6f\x6f\xc3\xa9'
dec = asn1.Decoder()
Expand Down Expand Up @@ -703,6 +721,12 @@ def test_error_object_identifier_with_too_large_first_component(self):
dec.start(buf)
pytest.raises(asn1.Error, dec.read)

def test_error_bitstring_with_too_many_unused_bits(self):
buf = b'\x03\x04\x08\x12\x34\x50'
dec = asn1.Decoder()
dec.start(buf)
pytest.raises(asn1.Error, dec.read)

def test_big_negative_integer(self):
buf = b'\x02\x10\xff\x7f\x2b\x3a\x4d\xea\x48\x1e\x1f\x37\x7b\xa8\xbd\x7f\xb0\x16'
dec = asn1.Decoder()
Expand Down

0 comments on commit 18b3b7d

Please sign in to comment.