diff --git a/pytm/pytm.py b/pytm/pytm.py index aaae8ce..944c854 100644 --- a/pytm/pytm.py +++ b/pytm/pytm.py @@ -355,7 +355,7 @@ def _apply_defaults(flows, data): try: e.overrides = e.sink.overrides e.overrides.extend( - f for f in e.source.overrides if f.id not in (f.id for f in e.overrides) + f for f in e.source.overrides if f.threat_id not in (f.threat_id for f in e.overrides) ) except ValueError: pass @@ -532,7 +532,8 @@ class Finding: severity = varString("", required=True, doc="Threat severity") mitigations = varString("", required=True, doc="Threat mitigations") example = varString("", required=True, doc="Threat example") - id = varString("", required=True, doc="Threat ID") + id = varInt("", required=True, doc="Finding ID") + threat_id = varString("", required=True, doc="Threat ID") references = varString("", required=True, doc="Threat references") response = varString( "", @@ -565,18 +566,18 @@ def __init__( "severity", "mitigations", "example", - "id", "references", ] threat = kwargs.pop("threat", None) if threat: + kwargs["threat_id"] = getattr(threat, "id") for a in attrs: # copy threat attrs into kwargs to allow to override them in next step kwargs[a] = getattr(threat, a) - threat_id = kwargs.get("id", None) + threat_id = kwargs.get("threat_id", None) for f in element.overrides: - if f.id != threat_id: + if f.threat_id != threat_id: continue for i in dir(f.__class__): attr = getattr(f.__class__, i) @@ -664,24 +665,27 @@ def _add_threats(self): TM._threats.append(Threat(**i)) def resolve(self): + finding_count = 0; findings = [] elements = defaultdict(list) for e in TM._elements: if not e.inScope: continue - override_ids = set(f.id for f in e.overrides) + override_ids = set(f.threat_id for f in e.overrides) # if element is a dataflow filter out overrides from source and sink # because they will be always applied there anyway try: - override_ids -= set(f.id for f in e.source.overrides + e.sink.overrides) + override_ids -= set(f.threat_id for f in e.source.overrides + e.sink.overrides) except AttributeError: pass for t in TM._threats: if not t.apply(e) and t.id not in override_ids: continue - f = Finding(e, threat=t) + + finding_count += 1 + f = Finding(e, id=finding_count, threat=t) findings.append(f) elements[e].append(f) self.findings = findings @@ -702,7 +706,7 @@ def check(self): _apply_defaults(TM._flows, TM._data) for e in TM._elements: - top = Counter(f.id for f in e.overrides).most_common(1) + top = Counter(f.threat_id for f in e.overrides).most_common(1) if not top: continue threat_id, count = top[0] diff --git a/tests/test_pytmfunc.py b/tests/test_pytmfunc.py index 78aee10..d6f7fa5 100644 --- a/tests/test_pytmfunc.py +++ b/tests/test_pytmfunc.py @@ -201,16 +201,16 @@ def test_resolve(self): self.maxDiff = None self.assertEqual( - [f.id for f in tm.findings], + [f.threat_id for f in tm.findings], ["Server", "Datastore", "Dataflow", "Dataflow", "Dataflow", "Dataflow"], ) - self.assertEqual([f.id for f in user.findings], []) - self.assertEqual([f.id for f in web.findings], ["Server"]) - self.assertEqual([f.id for f in db.findings], ["Datastore"]) - self.assertEqual([f.id for f in req.findings], ["Dataflow"]) - self.assertEqual([f.id for f in query.findings], ["Dataflow"]) - self.assertEqual([f.id for f in results.findings], ["Dataflow"]) - self.assertEqual([f.id for f in resp.findings], ["Dataflow"]) + self.assertEqual([f.threat_id for f in user.findings], []) + self.assertEqual([f.threat_id for f in web.findings], ["Server"]) + self.assertEqual([f.threat_id for f in db.findings], ["Datastore"]) + self.assertEqual([f.threat_id for f in req.findings], ["Dataflow"]) + self.assertEqual([f.threat_id for f in query.findings], ["Dataflow"]) + self.assertEqual([f.threat_id for f in results.findings], ["Dataflow"]) + self.assertEqual([f.threat_id for f in resp.findings], ["Dataflow"]) def test_overrides(self): random.seed(0) @@ -223,7 +223,7 @@ def test_overrides(self): web = Server( "Web Server", overrides=[ - Finding(id="Server", response="mitigated by adding TLS"), + Finding(threat_id="Server", response="mitigated by adding TLS"), ], ) db = Datastore( @@ -231,7 +231,7 @@ def test_overrides(self): inBoundary=server_db, overrides=[ Finding( - id="Datastore", response="accepted since inside the trust boundary" + threat_id="Datastore", response="accepted since inside the trust boundary" ), ], ) @@ -249,7 +249,7 @@ def test_overrides(self): self.maxDiff = None self.assertEqual( - [f.id for f in tm.findings], + [f.threat_id for f in tm.findings], ["Server", "Datastore"], ) self.assertEqual(