Skip to content

Commit

Permalink
A decent implementation of Comparable.
Browse files Browse the repository at this point in the history
I don't think we need anything more complicated like the solutions here:

python/typing#59

There are some questions about the correctness of other operators like
<= and >= that you now get for free with functools.total_ordering, due
to the unexpected interactions with the __eq__ you get for free from
dataclasses.dataclass, but we can punt them to the future.

Signed-off-by: Trishank Karthik Kuppusamy <trishank.kuppusamy@datadoghq.com>
  • Loading branch information
trishankatdatadog committed Nov 30, 2020
1 parent 081e581 commit f8d3437
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ source = ["tuf_on_a_plane"]

[tool.coverage.report]
show_missing = true
fail_under = 80
fail_under = 78
10 changes: 9 additions & 1 deletion src/tuf_on_a_plane/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@ def strptime(cls, date_string: str, format: str) -> "DateTime":


@total_ordering
class Natural:
class Comparable:
def __eq__(self, other: Any) -> bool:
raise NotImplementedError

def __lt__(self, other: Any) -> bool:
raise NotImplementedError


class Natural(Comparable):
def __init__(self, value: Any):
self.value = value

Expand Down
25 changes: 17 additions & 8 deletions src/tuf_on_a_plane/models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from securesystemslib.rsa_keys import verify_rsa_signature

from .common import (
Comparable,
DateTime,
Filepath,
Hashes,
Expand All @@ -26,19 +27,23 @@


@dataclass
class Signed:
class Signed(Comparable):
# __eq__ and __str__ autogenerated by dataclass.
expires: DateTime
spec_version: SpecVersion
version: Version

def __lt__(self, other: Any) -> bool:
if not isinstance(other, Signed):
raise NotImplementedError
return self.version < other.version

# NOTE: Overrides __eq__ and __lt__.
def __gt__(self, other: Any) -> bool:
if not isinstance(other, Signed):
raise NotImplementedError
return self.version > other.version

def __str__(self):
return f"{self.__class__.__name__}({self.version})"


@dataclass
class Metadata:
Expand Down Expand Up @@ -149,19 +154,23 @@ class Root(Signed):


@dataclass
class TimeSnap:
class TimeSnap(Comparable):
# __eq__ and __str__ autogenerated by dataclass.
version: Version
hashes: Optional[Hashes] = None
length: Optional[Length] = None

def __lt__(self, other: Any) -> bool:
if not isinstance(other, TimeSnap):
raise NotImplementedError
return self.version < other.version

# NOTE: Overrides __eq__ and __lt__.
def __gt__(self, other: Any) -> bool:
if not isinstance(other, TimeSnap):
raise NotImplementedError
return self.version > other.version

def __str__(self):
return f"{self.__class__.__name__}({self.version})"


TimeSnaps = Dict[Filepath, TimeSnap]

Expand Down
65 changes: 32 additions & 33 deletions src/tuf_on_a_plane/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
RollbackAttack,
)
from .models.common import (
Comparable,
Filepath,
Rolename,
Url,
Expand Down Expand Up @@ -49,9 +50,7 @@ def __check_expiry(self, signed: Signed) -> None:
if signed.expires <= self.config.NOW:
raise FreezeAttack(f"{signed}: {signed.expires} <= {self.config.NOW}")

# FIXME: maybe we can write a Comparable interface, but I'm too tired right
# now: https://github.com/python/typing/issues/59
def __check_rollback(self, prev: Signed, curr: Signed) -> None:
def __check_rollback(self, prev: Comparable, curr: Comparable) -> None:
if prev > curr:
raise RollbackAttack(f"{prev} > {curr}")

Expand Down Expand Up @@ -113,7 +112,8 @@ def __update_root(self) -> None:
"""5.1. Update the root metadata file."""
# 5.1.1. Let N denote the version number of the trusted root metadata
# file.
curr_root = self.__root
prev_root = self.__root
curr_root = prev_root
n = curr_root.version

# 5.1.8. Repeat steps 5.1.1 to 5.1.8.
Expand Down Expand Up @@ -144,32 +144,35 @@ def __update_root(self) -> None:
# file.
curr_root = metadata.signed

# 5.1.11. Set whether consistent snapshots are used as per the trusted
# root metadata file (see Section 4.3).
# NOTE: We violate the spec in checking this *before* deleting local
# timestamp and/or snapshot metadata, which I think is reasonable.
if not curr_root.consistent_snapshot:
raise NoConsistentSnapshotsError

# 5.1.9. Check for a freeze attack.
self.__check_expiry(curr_root)

# 5.1.10. If the timestamp and / or snapshot keys have been rotated,
# then delete the trusted timestamp and snapshot metadata files.
if (
self.__root.timestamp != curr_root.timestamp
or self.__root.snapshot != curr_root.snapshot
):
self.rm_file(self.__local_metadata_filename("snapshot"), ignore_errors=True)
self.rm_file(
self.__local_metadata_filename("timestamp"), ignore_errors=True
)
if prev_root < curr_root:
# 5.1.11. Set whether consistent snapshots are used as per the
# trusted root metadata file.
# NOTE: We violate the spec in checking this *before* deleting local
# timestamp and/or snapshot metadata, which I think is reasonable.
if not curr_root.consistent_snapshot:
raise NoConsistentSnapshotsError

# 5.1.10. If the timestamp and / or snapshot keys have been rotated,
# then delete the trusted timestamp and snapshot metadata files.
if (
self.__root.timestamp != curr_root.timestamp
or self.__root.snapshot != curr_root.snapshot
):
self.rm_file(
self.__local_metadata_filename("snapshot"), ignore_errors=True
)
self.rm_file(
self.__local_metadata_filename("timestamp"), ignore_errors=True
)

# 5.1.7. Persist root metadata.
# NOTE: We violate the spec in persisting only *after* checking
# everything, which I think is reasonable.
self.mv_file(tmp_file, self.__local_metadata_filename("root"))
self.__root = curr_root
# 5.1.7. Persist root metadata.
# NOTE: We violate the spec in persisting only *after* checking
# everything, which I think is reasonable.
self.mv_file(tmp_file, self.__local_metadata_filename("root"))
self.__root = curr_root

def __update_timestamp(self) -> None:
"""5.2. Download the timestamp metadata file."""
Expand All @@ -192,13 +195,9 @@ def __update_timestamp(self) -> None:
TimeSnap, prev_metadata.signed.snapshot
)
self.__check_rollback(prev_metadata.signed, curr_metadata.signed)

# FIXME: ideally, self.__check_rollback() takes Comparable so that
# we can reuse it.
if prev_metadata.signed.snapshot > curr_metadata.signed.snapshot:
raise RollbackAttack(
f"{prev_metadata.signed.snapshot} > {curr_metadata.signed.snapshot}"
)
self.__check_rollback(
prev_metadata.signed.snapshot, curr_metadata.signed.snapshot
)

# 5.2.3. Check for a freeze attack.
self.__check_expiry(curr_metadata.signed)
Expand Down

0 comments on commit f8d3437

Please sign in to comment.