From 8a964277ab3285591c93d6e9f9fc3288206eec54 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Fri, 7 Jun 2024 11:38:00 +0100 Subject: [PATCH] Set permissive schema on node metadata Also no longer need to fill empty json metadata with `{}` in tests/tsutil.py --- tests/test_inference.py | 27 ++++++++++---------- tests/tsutil.py | 4 +-- tsinfer/inference.py | 56 +++++++++++++++++++++++++---------------- 3 files changed, 50 insertions(+), 37 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index b7f860b5..bffd02dd 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1291,8 +1291,16 @@ def test_from_standard_tree_sequence(self): assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata) # Unless inference is perfect, internal nodes may differ, but sample nodes # should be identical - for n1, n2 in zip(ts.samples(), ts_inferred.samples()): - assert ts.node(n1) == ts_inferred.node(n2) + for u1, u2 in zip(ts.samples(), ts_inferred.samples()): + # NB - flags might differ if e.g. 
the node is a historical sample + # but original ones should be maintained + n1 = ts.node(u1) + n2 = ts_inferred.node(u2) + assert (n1.flags & n2.flags) == n1.flags # n1.flags is subset of n2.flags + assert n1.time == n2.time + assert n1.population == n2.population + assert n1.individual == n2.individual + assert tsutil.json_metadata_is_subset(n1.metadata, n2.metadata) # Sites can have metadata added by the inference process, but inferred site # metadata should always include all the metadata in the original ts for s1, s2 in zip(ts.sites(), ts_inferred.sites()): @@ -1906,7 +1914,7 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None): ancestors_time = ancestor_data.ancestors_time[:] num_ancestor_nodes = 0 for n in ancestors_ts.nodes(): - md = json.loads(n.metadata) if n.metadata else {} + md = n.metadata if tsinfer.is_pc_ancestor(n.flags): assert not ("ancestor_data_id" in md) else: @@ -2966,7 +2974,6 @@ def verify(self, ts): last_node = ts1.node(ts1.num_nodes - 1) assert np.max(ts1.tables.nodes.time) == last_node.time md = last_node.metadata - md = json.loads(md.decode()) # At the moment node metadata has no schema assert md.get("ancestor_data_id", None) != 0 # When not post processing and there is no path compression, @@ -2977,7 +2984,6 @@ def verify(self, ts): first_node = ts2.node(0) assert np.max(ts2.tables.nodes.time) == first_node.time md = first_node.metadata - md = json.loads(md.decode()) # At the moment node metadata has no schema assert md["ancestor_data_id"] == 0 @pytest.mark.parametrize("simp", [True, False]) @@ -3015,7 +3021,6 @@ def test_standalone_post_process(self, medium_sd_fixture): oldest_parent_id = ts_unsimplified.edge(-1).parent assert oldest_parent_id == 0 md = ts_unsimplified.node(oldest_parent_id).metadata - md = json.loads(md.decode()) # At the moment node metadata has no schema assert md["ancestor_data_id"] == 0 # Post processing removes ancestor_data_id 0 @@ -3024,7 +3029,6 @@ def test_standalone_post_process(self, 
medium_sd_fixture): oldest_parent_id = ts.edge(-1).parent assert np.sum(ts.tables.nodes.time == ts.node(oldest_parent_id).time) == 1 md = ts.node(oldest_parent_id).metadata - md = json.loads(md.decode()) # At the moment node metadata has no schema assert md["ancestor_data_id"] == 1 ts = tsinfer.post_process( @@ -3036,7 +3040,6 @@ def test_standalone_post_process(self, medium_sd_fixture): for tree in ts.trees(): roots.add(tree.root) md = ts.node(tree.root).metadata - md = json.loads(md.decode()) # At the moment node metadata has no schema assert md["ancestor_data_id"] == 1 assert len(roots) > 1 @@ -3815,8 +3818,7 @@ def verify_augmented_ancestors( node = t2.nodes[m + j] assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR assert node.time == 1 - metadata = json.loads(node.metadata.decode()) - assert node_id == metadata["sample_data_id"] + assert node_id == node.metadata["sample_data_id"] t2.nodes.truncate(len(t1.nodes)) # Adding and subtracting 1 can lead to small diffs, so we compare @@ -3824,7 +3826,7 @@ def verify_augmented_ancestors( t2.nodes.time -= 1.0 assert np.allclose(t2.nodes.time, t1.nodes.time) t2.nodes.time = t1.nodes.time - assert t1.nodes == t2.nodes + t1.nodes.assert_equals(t2.nodes, ignore_metadata=True) if not path_compression: # If we have path compression it's possible that some older edges # will be compressed out. 
@@ -3966,8 +3968,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression): num_sample_ancestors = 0 for node in final_ts.nodes(): if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR: - metadata = json.loads(node.metadata.decode()) - assert metadata["sample_data_id"] in subset + assert node.metadata["sample_data_id"] in subset num_sample_ancestors += 1 assert expected_sample_ancestors == num_sample_ancestors tsinfer.verify(samples, final_ts.simplify()) diff --git a/tests/tsutil.py b/tests/tsutil.py index 423acd46..e17f5626 100644 --- a/tests/tsutil.py +++ b/tests/tsutil.py @@ -59,10 +59,10 @@ def add_default_schemas(ts): tables.populations[pop.id] = pop tables.individuals.metadata_schema = schema assert len(tables.individuals.metadata) == 0 - tables.individuals.packset_metadata([b"{}"] * ts.num_individuals) tables.sites.metadata_schema = schema assert len(tables.sites.metadata) == 0 - tables.sites.packset_metadata([b"{}"] * ts.num_sites) + tables.nodes.metadata_schema = schema + assert len(tables.nodes.metadata) == 0 return tables.tree_sequence() diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 503692b6..e8dcb876 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -81,6 +81,21 @@ ], } +node_ancestor_data_id_metadata_definition = { + "description": ( + "The ID of the tsinfer ancestor data node from which this node is derived." + ), + "type": "number", +} + +node_sample_data_id_metadata_definition = { + "description": ( + "The ID of the tsinfer sample data node from which this node is derived. " + "Only present for nodes in which historical samples are treated as ancestors." + ), + "type": "number", +} + class LMDBCache: def __init__(self, lmdb): @@ -1879,7 +1894,7 @@ def match_ancestors(self): logger.info("Finished ancestor matching") return ts - def get_ancestors_tables(self): + def fill_ancestors_tables(self, tables): """ Return the ancestors tree sequence tables. 
Only inference sites are included in this tree sequence. All nodes have the sample flag bit set, and if a node @@ -1888,21 +1903,10 @@ def get_ancestors_tables(self): logger.debug("Building ancestors tree sequence") tsb = self.tree_sequence_builder - tables = tskit.TableCollection( - sequence_length=self.ancestor_data.sequence_length - ) - flags, times = tsb.dump_nodes() pc_ancestors = is_pc_ancestor(flags) tables.nodes.set_columns(flags=flags, time=times) - # # FIXME we should do this as a struct codec? - # dict_schema = permissive_json_schema() - # dict_schema = add_to_schema(dict_schema, "ancestor_data_id", - # {"type": "integer"}) - # schema = tskit.MetadataSchema(dict_schema) - # tables.nodes.schema = schema - # Add metadata for any non-PC node, pointing to the original ancestor metadata = [] ancestor = 0 @@ -1940,16 +1944,20 @@ def get_ancestors_tables(self): len(tables.sites), ) ) - return tables def store_output(self): + tables = tskit.TableCollection( + sequence_length=self.ancestor_data.sequence_length + ) + # We decided to use a permissive schema for the metadata, for flexibility + dict_schema = tskit.MetadataSchema.permissive_json().schema + dict_schema = add_to_schema( + dict_schema, "ancestor_data_id", node_ancestor_data_id_metadata_definition + ) + tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema) + if self.num_ancestors > 0: - tables = self.get_ancestors_tables() - else: - # Allocate an empty tree sequence. 
- tables = tskit.TableCollection( - sequence_length=self.ancestor_data.sequence_length - ) + self.fill_ancestors_tables(tables) tables.time_units = self.time_units return tables.tree_sequence() @@ -2385,6 +2393,12 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes): logger.debug("Building augmented ancestors tree sequence") tsb = self.tree_sequence_builder tables = self.ancestors_ts_tables.copy() + dict_schema = tables.nodes.metadata_schema.schema + assert dict_schema is not None + dict_schema = add_to_schema( + dict_schema, "sample_data_id", node_sample_data_id_metadata_definition + ) + tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema) flags, times = tsb.dump_nodes() s = 0 @@ -2395,9 +2409,7 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes): tables.nodes.add_row( flags=constants.NODE_IS_SAMPLE_ANCESTOR, time=times[j], - metadata=_encode_raw_metadata( - {"sample_data_id": int(sample_indexes[s])} - ), + metadata={"sample_data_id": int(sample_indexes[s])}, ) s += 1 else: