Skip to content

Commit

Permalink
Set permissive schema on node metadata
Browse files Browse the repository at this point in the history
Also no longer need to fill empty JSON metadata with `{}` in tests/tsutil.py
  • Loading branch information
hyanwong authored and benjeffery committed Jul 17, 2024
1 parent 141f0c7 commit 8a96427
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 37 deletions.
27 changes: 14 additions & 13 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,8 +1291,16 @@ def test_from_standard_tree_sequence(self):
assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata)
# Unless inference is perfect, internal nodes may differ, but sample nodes
# should be identical
for n1, n2 in zip(ts.samples(), ts_inferred.samples()):
assert ts.node(n1) == ts_inferred.node(n2)
for u1, u2 in zip(ts.samples(), ts_inferred.samples()):
# NB - flags might differ if e.g. the node is a historical sample
# but original ones should be maintained
n1 = ts.node(u1)
n2 = ts.node(u2)
assert (n1.flags & n2.flags) == n1.flags # n1.flags is subset of n2.flags
assert n1.time == n2.time
assert n1.population == n2.population
assert n1.individual == n2.individual
assert tsutil.json_metadata_is_subset(n1.metadata, n2.metadata)
# Sites can have metadata added by the inference process, but inferred site
# metadata should always include all the metadata in the original ts
for s1, s2 in zip(ts.sites(), ts_inferred.sites()):
Expand Down Expand Up @@ -1906,7 +1914,7 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None):
ancestors_time = ancestor_data.ancestors_time[:]
num_ancestor_nodes = 0
for n in ancestors_ts.nodes():
md = json.loads(n.metadata) if n.metadata else {}
md = n.metadata
if tsinfer.is_pc_ancestor(n.flags):
assert not ("ancestor_data_id" in md)
else:
Expand Down Expand Up @@ -2966,7 +2974,6 @@ def verify(self, ts):
last_node = ts1.node(ts1.num_nodes - 1)
assert np.max(ts1.tables.nodes.time) == last_node.time
md = last_node.metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md.get("ancestor_data_id", None) != 0

# When not post processing and there is no path compression,
Expand All @@ -2977,7 +2984,6 @@ def verify(self, ts):
first_node = ts2.node(0)
assert np.max(ts2.tables.nodes.time) == first_node.time
md = first_node.metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md["ancestor_data_id"] == 0

@pytest.mark.parametrize("simp", [True, False])
Expand Down Expand Up @@ -3015,7 +3021,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
oldest_parent_id = ts_unsimplified.edge(-1).parent
assert oldest_parent_id == 0
md = ts_unsimplified.node(oldest_parent_id).metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md["ancestor_data_id"] == 0

# Post processing removes ancestor_data_id 0
Expand All @@ -3024,7 +3029,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
oldest_parent_id = ts.edge(-1).parent
assert np.sum(ts.tables.nodes.time == ts.node(oldest_parent_id).time) == 1
md = ts.node(oldest_parent_id).metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md["ancestor_data_id"] == 1

ts = tsinfer.post_process(
Expand All @@ -3036,7 +3040,6 @@ def test_standalone_post_process(self, medium_sd_fixture):
for tree in ts.trees():
roots.add(tree.root)
md = ts.node(tree.root).metadata
md = json.loads(md.decode()) # At the moment node metadata has no schema
assert md["ancestor_data_id"] == 1
assert len(roots) > 1

Expand Down Expand Up @@ -3815,16 +3818,15 @@ def verify_augmented_ancestors(
node = t2.nodes[m + j]
assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR
assert node.time == 1
metadata = json.loads(node.metadata.decode())
assert node_id == metadata["sample_data_id"]
assert node_id == node.metadata["sample_data_id"]

t2.nodes.truncate(len(t1.nodes))
# Adding and subtracting 1 can lead to small diffs, so we compare
# the time separately.
t2.nodes.time -= 1.0
assert np.allclose(t2.nodes.time, t1.nodes.time)
t2.nodes.time = t1.nodes.time
assert t1.nodes == t2.nodes
t1.nodes.assert_equals(t2.nodes, ignore_metadata=True)
if not path_compression:
# If we have path compression it's possible that some older edges
# will be compressed out.
Expand Down Expand Up @@ -3966,8 +3968,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
num_sample_ancestors = 0
for node in final_ts.nodes():
if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR:
metadata = json.loads(node.metadata.decode())
assert metadata["sample_data_id"] in subset
assert node.metadata["sample_data_id"] in subset
num_sample_ancestors += 1
assert expected_sample_ancestors == num_sample_ancestors
tsinfer.verify(samples, final_ts.simplify())
Expand Down
4 changes: 2 additions & 2 deletions tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def add_default_schemas(ts):
tables.populations[pop.id] = pop
tables.individuals.metadata_schema = schema
assert len(tables.individuals.metadata) == 0
tables.individuals.packset_metadata([b"{}"] * ts.num_individuals)
tables.sites.metadata_schema = schema
assert len(tables.sites.metadata) == 0
tables.sites.packset_metadata([b"{}"] * ts.num_sites)
tables.nodes.metadata_schema = schema
assert len(tables.nodes.metadata) == 0
return tables.tree_sequence()


Expand Down
56 changes: 34 additions & 22 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,21 @@
],
}

# JSON-schema property definition for the "ancestor_data_id" key of the node
# metadata; it is passed to add_to_schema() when the ancestors tables are built.
node_ancestor_data_id_metadata_definition = {
    "description": "The ID of the tsinfer ancestor data node "
    "from which this node is derived.",
    "type": "number",
}

# JSON-schema property definition for the "sample_data_id" key of the node
# metadata; it is passed to add_to_schema() when augmented ancestors are built.
node_sample_data_id_metadata_definition = {
    "description": "The ID of the tsinfer sample data node "
    "from which this node is derived. "
    "Only present for nodes in which "
    "historical samples are treated as ancestors.",
    "type": "number",
}


class LMDBCache:
def __init__(self, lmdb):
Expand Down Expand Up @@ -1879,7 +1894,7 @@ def match_ancestors(self):
logger.info("Finished ancestor matching")
return ts

def get_ancestors_tables(self):
def fill_ancestors_tables(self, tables):
"""
Fill the supplied tables with the ancestors tree sequence. Only inference sites are included in
this tree sequence. All nodes have the sample flag bit set, and if a node
Expand All @@ -1888,21 +1903,10 @@ def get_ancestors_tables(self):
logger.debug("Building ancestors tree sequence")
tsb = self.tree_sequence_builder

tables = tskit.TableCollection(
sequence_length=self.ancestor_data.sequence_length
)

flags, times = tsb.dump_nodes()
pc_ancestors = is_pc_ancestor(flags)
tables.nodes.set_columns(flags=flags, time=times)

# # FIXME we should do this as a struct codec?
# dict_schema = permissive_json_schema()
# dict_schema = add_to_schema(dict_schema, "ancestor_data_id",
# {"type": "integer"})
# schema = tskit.MetadataSchema(dict_schema)
# tables.nodes.schema = schema

# Add metadata for any non-PC node, pointing to the original ancestor
metadata = []
ancestor = 0
Expand Down Expand Up @@ -1940,16 +1944,20 @@ def get_ancestors_tables(self):
len(tables.sites),
)
)
return tables

def store_output(self):
tables = tskit.TableCollection(
sequence_length=self.ancestor_data.sequence_length
)
# We decided to use a permissive schema for the metadata, for flexibility
dict_schema = tskit.MetadataSchema.permissive_json().schema
dict_schema = add_to_schema(
dict_schema, "ancestor_data_id", node_ancestor_data_id_metadata_definition
)
tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema)

if self.num_ancestors > 0:
tables = self.get_ancestors_tables()
else:
# Allocate an empty tree sequence.
tables = tskit.TableCollection(
sequence_length=self.ancestor_data.sequence_length
)
self.fill_ancestors_tables(tables)
tables.time_units = self.time_units
return tables.tree_sequence()

Expand Down Expand Up @@ -2385,6 +2393,12 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
logger.debug("Building augmented ancestors tree sequence")
tsb = self.tree_sequence_builder
tables = self.ancestors_ts_tables.copy()
dict_schema = tables.nodes.metadata_schema.schema
assert dict_schema is not None
dict_schema = add_to_schema(
dict_schema, "sample_data_id", node_sample_data_id_metadata_definition
)
tables.nodes.metadata_schema = tskit.MetadataSchema(dict_schema)

flags, times = tsb.dump_nodes()
s = 0
Expand All @@ -2395,9 +2409,7 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
tables.nodes.add_row(
flags=constants.NODE_IS_SAMPLE_ANCESTOR,
time=times[j],
metadata=_encode_raw_metadata(
{"sample_data_id": int(sample_indexes[s])}
),
metadata={"sample_data_id": int(sample_indexes[s])},
)
s += 1
else:
Expand Down

0 comments on commit 8a96427

Please sign in to comment.