Skip to content

Commit

Permalink
Numpy2 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong authored and mergify[bot] committed Jul 24, 2024
1 parent 6b2116d commit 599920d
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
10 changes: 6 additions & 4 deletions tsinfer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,25 @@
"""
import enum

import numpy as np

C_ENGINE = "C"
PY_ENGINE = "P"


# TODO Change these to use the enum.IntFlag class

# Bit 16 is set in node flags when they have been created by path compression.
NODE_IS_PC_ANCESTOR = 1 << 16
NODE_IS_PC_ANCESTOR = np.uint32(1 << 16)
# Bit 17 is set in node flags when they have been created by shared recombination
# breakpoint
NODE_IS_SRB_ANCESTOR = 1 << 17
NODE_IS_SRB_ANCESTOR = np.uint32(1 << 17)
# Bit 18 is set in node flags when they are samples inserted to augment existing
# ancestors.
NODE_IS_SAMPLE_ANCESTOR = 1 << 18
NODE_IS_SAMPLE_ANCESTOR = np.uint32(1 << 18)
# Bit 20 is set in node flags when they are samples not at time zero in the sampledata
# file
NODE_IS_HISTORICAL_SAMPLE = 1 << 20
NODE_IS_HISTORICAL_SAMPLE = np.uint32(1 << 20)

# What type of inference have we done at a site?
INFERENCE_NONE = "none"
Expand Down
4 changes: 3 additions & 1 deletion tsinfer/eval_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,9 @@ def extract_ancestors(samples, ts):
index = tables.nodes.flags == tskit.NODE_IS_SAMPLE
flags[index] = tskit.NODE_IS_SAMPLE
index = tables.nodes.flags != tskit.NODE_IS_SAMPLE
flags[index] = np.bitwise_and(tables.nodes.flags[index], ~tskit.NODE_IS_SAMPLE)
flags[index] = np.bitwise_and(
tables.nodes.flags[index], ~flags.dtype.type(tskit.NODE_IS_SAMPLE)
)

tables.nodes.set_columns(
flags=flags,
Expand Down
2 changes: 1 addition & 1 deletion tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,7 +1271,7 @@ def sites_time(self):
@sites_time.setter
def sites_time(self, value):
self._check_edit_mode()
self.data["sites/time"][:] = np.array(value, dtype=np.float64, copy=False)
self.data["sites/time"][:] = np.asarray(value, dtype=np.float64)

@property
def sites_alleles(self):
Expand Down
4 changes: 2 additions & 2 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def count_pc_ancestors(flags):
Returns the number of values in the specified array which have the
NODE_IS_PC_ANCESTOR set.
"""
flags = np.array(flags, dtype=np.uint32, copy=False)
flags = np.asarray(flags, dtype=np.uint32)
return np.sum(is_pc_ancestor(flags))


Expand All @@ -157,7 +157,7 @@ def count_srb_ancestors(flags):
Returns the number of values in the specified array which have the
NODE_IS_SRB_ANCESTOR set.
"""
flags = np.array(flags, dtype=np.uint32, copy=False)
flags = np.asarray(flags, dtype=np.uint32)
return np.sum(np.bitwise_and(flags, constants.NODE_IS_SRB_ANCESTOR) != 0)


Expand Down

0 comments on commit 599920d

Please sign in to comment.