Skip to content

Commit

Permalink
Merge pull request #931 from hyanwong/set-node-md-schema
Browse files Browse the repository at this point in the history
Set node metadata schema
  • Loading branch information
benjeffery authored Jul 17, 2024
2 parents 141f0c7 + 8a96427 commit a41d838
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 37 deletions.
27 changes: 14 additions & 13 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,8 +1291,16 @@ def test_from_standard_tree_sequence(self):
assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata)
# Unless inference is perfect, internal nodes may differ, but sample nodes
# should be identical
for n1, n2 in zip(ts.samples(), ts_inferred.samples()):
assert ts.node(n1) == ts_inferred.node(n2)
for u1, u2 in zip(ts.samples(), ts_inferred.samples()):
# NB - flags might differ if e.g. the node is a historical sample
# but original ones should be maintained
n1 = ts.node(u1)
n2 = ts.node(u2)
assert (n1.flags & n2.flags) == n1.flags # n1.flags is subset of n2.flags
assert n1.time == n2.time
assert n1.population == n2.population
assert n1.individual == n2.individual
assert tsutil.json_metadata_is_subset(n1.metadata, n2.metadata)
# Sites can have metadata added by the inference process, but inferred site
# metadata should always include all the metadata in the original ts
for s1, s2 in zip(ts.sites(), ts_inferred.sites()):
Expand Down Expand Up @@ -1906,7 +1914,7 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None):
ancestors_time = ancestor_data.ancestors_time[:]
num_ancestor_nodes = 0
for n in ancestors_ts.nodes():
md = json.loads(n.metadata) if n.metadata else {}
md = n.metadata
if tsinfer.is_pc_ancestor(n.flags):
assert not ("ancestor_data_id" in md)
else:
Expand Down Expand Up @@ -2966,7 +2974,6 @@ def verify(self, ts):
last_node = ts1.node(ts1.num_nodes - 1)
assert np.max(ts1.tables.nodes.time) == last_node.time
md = last_node.metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md.get("ancestor_data_id", None) != 0

# When not post processing and there is no path compression,
Expand All @@ -2977,7 +2984,6 @@ def verify(self, ts):
first_node = ts2.node(0)
assert np.max(ts2.tables.nodes.time) == first_node.time
md = first_node.metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md["ancestor_data_id"] == 0

@pytest.mark.parametrize("simp", [True, False])
Expand Down Expand Up @@ -3015,7 +3021,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
oldest_parent_id = ts_unsimplified.edge(-1).parent
assert oldest_parent_id == 0
md = ts_unsimplified.node(oldest_parent_id).metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md["ancestor_data_id"] == 0

# Post processing removes ancestor_data_id 0
Expand All @@ -3024,7 +3029,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
oldest_parent_id = ts.edge(-1).parent
assert np.sum(ts.tables.nodes.time == ts.node(oldest_parent_id).time) == 1
md = ts.node(oldest_parent_id).metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md["ancestor_data_id"] == 1

ts = tsinfer.post_process(
Expand All @@ -3036,7 +3040,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
for tree in ts.trees():
roots.add(tree.root)
md = ts.node(tree.root).metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md["ancestor_data_id"] == 1
assert len(roots) > 1

Expand Down Expand Up @@ -3815,16 +3818,15 @@ def verify_augmented_ancestors(
node = t2.nodes[m + j]
assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR
assert node.time == 1
metadata = json.loads(node.metadata.decode())
assert node_id == metadata["sample_data_id"]
assert node_id == node.metadata["sample_data_id"]

t2.nodes.truncate(len(t1.nodes))
# Adding and subtracting 1 can lead to small diffs, so we compare
# the time separately.
t2.nodes.time -= 1.0
assert np.allclose(t2.nodes.time, t1.nodes.time)
t2.nodes.time = t1.nodes.time
assert t1.nodes == t2.nodes
t1.nodes.assert_equals(t2.nodes, ignore_metadata=True)
if not path_compression:
# If we have path compression it's possible that some older edges
# will be compressed out.
Expand Down Expand Up @@ -3966,8 +3968,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
num_sample_ancestors = 0
for node in final_ts.nodes():
if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR:
metadata = json.loads(node.metadata.decode())
assert metadata["sample_data_id"] in subset
assert node.metadata["sample_data_id"] in subset
num_sample_ancestors += 1
assert expected_sample_ancestors == num_sample_ancestors
tsinfer.verify(samples, final_ts.simplify())
Expand Down
4 changes: 2 additions & 2 deletions tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def add_default_schemas(ts):
tables.populations[pop.id] = pop
tables.individuals.metadata_schema = schema
assert len(tables.individuals.metadata) == 0
tables.individuals.packset_metadata([b"{}"] * ts.num_individuals)
tables.sites.metadata_schema = schema
assert len(tables.sites.metadata) == 0
tables.sites.packset_metadata([b"{}"] * ts.num_sites)
tables.nodes.metadata_schema = schema
assert len(tables.nodes.metadata) == 0
return tables.tree_sequence()


Expand Down
56 changes: 34 additions & 22 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,21 @@
],
}

node_ancestor_data_id_metadata_definition = {
"description": (
"The ID of the tsinfer ancestor data node from which this node is derived."
),
"type": "number",
}

node_sample_data_id_metadata_definition = {
"description": (
"The ID of the tsinfer sample data node from which this node is derived. "
"Only present for nodes in which historical samples are treated as ancestors."
),
"type": "number",
}


class LMDBCache:
def __init__(self, lmdb):
Expand Down Expand Up @@ -1879,7 +1894,7 @@ def match_ancestors(self):
logger.info("Finished ancestor matching")
return ts

def get_ancestors_tables(self):
def fill_ancestors_tables(self, tables):
"""
Return the ancestors tree sequence tables. Only inference sites are included in
this tree sequence. All nodes have the sample flag bit set, and if a node
Expand All @@ -1888,21 +1903,10 @@ def get_ancestors_tables(self):
logger.debug("Building ancestors tree sequence")
tsb = self.tree_sequence_builder

tables = tskit.TableCollection(
sequence_length=self.ancestor_data.sequence_length
)

flags, times = tsb.dump_nodes()
pc_ancestors = is_pc_ancestor(flags)
tables.nodes.set_columns(flags=flags, time=times)

# # FIXME we should do this as a struct codec?
# dict_schema = permissive_json_schema()
# dict_schema = add_to_schema(dict_schema, "ancestor_data_id",
# {"type": "integer"})
# schema = tskit.MetadataSchema(dict_schema)
# tables.nodes.schema = schema

# Add metadata for any non-PC node, pointing to the original ancestor
metadata = []
ancestor = 0
Expand Down Expand Up @@ -1940,16 +1944,20 @@ def get_ancestors_tables(self):
len(tables.sites),
)
)
return tables

def store_output(self):
tables = tskit.TableCollection(
sequence_length=self.ancestor_data.sequence_length
)
# We decided to use a permissive schema for the metadata, for flexibility
dict_schema = tskit.MetadataSchema.permissive_json().schema
dict_schema = add_to_schema(
dict_schema, "ancestor_data_id", node_ancestor_data_id_metadata_definition
)
tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema)

if self.num_ancestors > 0:
tables = self.get_ancestors_tables()
else:
# Allocate an empty tree sequence.
tables = tskit.TableCollection(
sequence_length=self.ancestor_data.sequence_length
)
self.fill_ancestors_tables(tables)
tables.time_units = self.time_units
return tables.tree_sequence()

Expand Down Expand Up @@ -2385,6 +2393,12 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
logger.debug("Building augmented ancestors tree sequence")
tsb = self.tree_sequence_builder
tables = self.ancestors_ts_tables.copy()
dict_schema = tables.nodes.metadata_schema.schema
assert dict_schema is not None
dict_schema = add_to_schema(
dict_schema, "sample_data_id", node_sample_data_id_metadata_definition
)
tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema)

flags, times = tsb.dump_nodes()
s = 0
Expand All @@ -2395,9 +2409,7 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
tables.nodes.add_row(
flags=constants.NODE_IS_SAMPLE_ANCESTOR,
time=times[j],
metadata=_encode_raw_metadata(
{"sample_data_id": int(sample_indexes[s])}
),
metadata={"sample_data_id": int(sample_indexes[s])},
)
s += 1
else:
Expand Down

0 comments on commit a41d838

Please sign in to comment.