Skip to content

Commit

Permalink
NTI SD: Initial work on faster NTI SD
Browse files Browse the repository at this point in the history
sum_subtract is a work in progress,
possibly to do with the bnok ratios.

Updates #790
  • Loading branch information
shawnlaffan committed May 17, 2021
1 parent 3129170 commit db491eb
Show file tree
Hide file tree
Showing 2 changed files with 313 additions and 2 deletions.
127 changes: 125 additions & 2 deletions lib/Biodiverse/Tree.pm
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use Scalar::Util qw /looks_like_number blessed/;
use List::MoreUtils qw /first_index/;
use List::Util qw /sum min max uniq any/;
use Ref::Util qw { :all };
use Sort::Key qw /keysort/;
use Sort::Key qw /keysort rnkeysort/;
use Sort::Key::Natural qw /natkeysort/;
use POSIX qw /floor ceil/;

Expand Down Expand Up @@ -2690,7 +2690,9 @@ sub AUTOLOAD {
if ( defined $root_node and $root_node->can($method) ) {

#print "[TREE] Using AUTOLOADER method $method\n";
return $root_node->$method(@_);
my $retval = eval {$root_node->$method(@_)};
croak "$EVAL_ERROR in method $method" if $EVAL_ERROR;
return $retval;
}
else {
Biodiverse::NoMethod->throw(
Expand Down Expand Up @@ -3203,7 +3205,128 @@ sub get_nti_expected_sd {
= keysort {$_->get_name}
grep {!$_->is_root_node}
$self->get_node_refs;

\my %by_se = $self->get_cached_value_dor_set_default_aa (NODE_NTI_LEN_CACHE => {});
if (!keys %by_se) {
foreach my $node (@node_refs) {
my $te = $node->get_terminal_element_count;
$by_se{$te} += $node->get_length * $te;
}
}

# names from PhyloMeasures
#_compute_subtree_sums(sum_subtree, sum_subtract);
my $sum_subtree = $self->get_cached_value ("NTI_SUM_SUBTREE");
my $sum_subtract = $self->get_cached_value ("NTI_SUM_SUBTRACT");

if (!defined $sum_subtree) {
($sum_subtree, $sum_subtract) = $self->get_nti_sd_subtree_bits(
ln_fac_array => \@ln_fac_arr, # underhanded
sample_count => $r,
);
## two-pass approach for now
#my @nodes_by_depth = rnkeysort {$_->get_depth} @node_refs;
#my %tip_pr_vals;
#foreach my $node (@nodes_by_depth) {
# my $name = $node->get_name;
# my $length = $len_cache{$name} //= $node->get_length;
# my $tip_count = $tip_count_cache{$name} //= $node->get_terminal_element_count;
# # sum of probs is the sum of the child probs
# $tip_pr_vals{$name} = 0;
# #$tip_pr_vals{$name} = $length * $tip_count;
# foreach my $child ($node->get_children) {
# my $ch_name = $child->get_name;
# $tip_pr_vals{$name} += $tip_pr_vals{$ch_name};
# }
#}
#my %sum_of_products;
#foreach my $node (@nodes_by_depth) {
# my $sum_of_pr
# = $sum_of_products{$node->get_parent->get_name}
# // 0;
# my $se
# = $tip_count_cache{$name1}
# //= $node->get_terminal_element_count;
# my $bnok_ratio = $s - $se - $r + 1 > 0
# ? $ln_fac_arr[$s-$se]
# - ( $ln_fac_arr[$r-1]
# + $ln_fac_arr[$s - $se - $r + 1]
# )
# - $bnok_sr
# : -$bnok_sr;
# foreach my $child ($node->get_children) {
# $sum_of_products{$node->get_name} += $sum_of_pr;
# }
#}
}

# names from PhyloMeasures
my ($sum_self, $sum_self_third_case,
$sum_same_class_third_case, $sum_third_case);

foreach my $node (@node_refs) {
my $name = $node->get_name;
my $len = $len_cache{$name} //= $node->get_length;
my $se = $tip_count_cache{$name} //= $node->get_terminal_element_count;
#my $anc1 = $ancestor_cache{$name} //= $node->get_path_lengths_to_root_node_aa;
my $bnok_ratio
= $s - $se - $r + 1 > 0
? $ln_fac_arr[$s-$se]
- ( $ln_fac_arr[$r-1]
+ $ln_fac_arr[$s - $se - $r + 1]
)
- $bnok_sr
: -$bnok_sr;

$sum_self += $se * exp ($bnok_ratio) * $len ** 2;

$sum_self_third_case = $len ** 2 * $se ** 2 * exp ($bnok_ratio);
# need to check two_edge_pr ind and anc with same input
# - should be the same result
}

##my @lens = sort {$a <=> $b} keys %by_length;
#foreach my $se (sort {$a <=> $b} keys %by_se) {
# my $val = $by_length{$len1};
#
# my $bnok_ratio
# = $s - $se - $r + 1 > 0
# ? $ln_fac_arr[$s-$se]
# - ( $ln_fac_arr[$r-1]
# + $ln_fac_arr[$s - $se - $r + 1]
# )
# - $bnok_sr
# : -$bnok_sr;
#
# $sum_same_class_third_case += ($val**2) * exp ($bnok_ratio);
#
# foreach my $sl (keys %by_length) {
# my $bnok_ratio
# = $s - $sl - $se - $r + 2 > 0
# ? $ln_fac_arr[$s-$sl-$se]
# - ( $ln_fac_arr[$r-2]
# + $ln_fac_arr[$s - $se - $sl - $r + 2]
# )
# - $bnok_sr
# : -$bnok_sr;
# $sum_third_case += ($val**2) * exp ($bnok_ratio);
# }
#}
#
#$sum_same_class_third_case
# = (($sum_same_class_third_case - $sum_self_third_case) / 2
# + $sum_self_third_case;
##sum_same_class_third_case = ((sum_same_class_third_case - sum_self_third_case)/Number_type(2.0))
## + sum_self_third_case;
#
#
#$sum_third_case += $sum_same_class_third_case;
##sum_third_case += sum_same_class_third_case;
#my $total_sum = 2 * ($sum_third_case - $sum_subtract + $sum_subtree)) - $sum_self;
#$total_sum = 4 * $total_sum / ($s ** 2);
#


my $n_nodes = @node_refs;
my $progress
= $n_nodes > 1000
Expand Down
188 changes: 188 additions & 0 deletions lib/Biodiverse/TreeNode.pm
Original file line number Diff line number Diff line change
Expand Up @@ -2567,6 +2567,194 @@ sub get_nri_all_weights {
return $value;
}

# eventually the NTI sd calcs can be moved back into Bd::Tree,
# but for now we need a recursive approach
my $nti_sum_of_products_cache_name = 'NTI_SUM_OF_PRODUCTS_BELOW';

sub get_nti_sum_of_products_below {
my ($self) = @_;

my $value
= $self->get_cached_value ($nti_sum_of_products_cache_name);

return $value if defined $value;

# run through the tree, calculating the sums
# no need to worry about LCA in this case
$self->get_root_node->_calc_nti_sum_of_products_below;

return $self->get_cached_value ($nti_sum_of_products_cache_name);
}

# a bit messy - could do externally adding to parents
# includees self now, need to change name
sub _calc_nti_sum_of_products_below {
my ($self) = @_;

my $sum = 0;
foreach my $child ($self->get_children) {
$sum += $child->_calc_nti_sum_of_products_below;
}

$sum += $self->get_length * $self->get_terminal_element_count;

# cache the values below
$self->set_cached_value ($nti_sum_of_products_cache_name => $sum);

# and return ourselves plus the sum from below
return $sum;
}


sub get_nti_sd_subtree_bits {
my ($self, %args) = @_;
my $r = $args{sample_count} // croak "need sample_counts argument";
my $cache_name = 'NTI_SUBTREE_BITS';
my $hash = $self->get_cached_value_dor_set_default_aa ($cache_name => {});
my $vals = $hash->{$r};

if (!defined $vals) {
# we need to do the whole tree
my $root = $self->get_root_node;

# ensure we have the sum of products ready for us
$root->get_nti_sum_of_products_below;

# work from the LCA if it is not the root
while ($root->get_child_count == 1) {
my $child_arr = $root->get_children;
$root = $child_arr->[0];
}
# don't calc the root
my ($sum_subtree, $sum_subtract);
foreach my $child ($root->get_children) {
my ($s1, $s2) = $child->_calc_nti_sd_subtree_bits (%args);
$sum_subtree += $s1;
$sum_subtract += $s2;
}
$hash->{$r} = $vals = [$sum_subtree, $sum_subtract];
}

return wantarray ? @$vals : $vals;
}

sub _calc_nti_sd_subtree_bits {
my ($self, %args) = @_;

# requiring this to be passed is dirty and underhanded
# it should be in Bd::Common
\my @ln_fac_arr = $args{ln_fac_array} // croak 'need the ln_fac_array';
my $r = $args{r} // $args{sample_count};

# horrible that we pass these through - could just use a map in Bd::Tree
my $sum_subtree = $args{sum_subtree};
my $sum_subtract = $args{sum_subtract};


my $length = $self->get_length;
my $se = $self->get_terminal_element_count;
my $s = $self->get_root_node->get_terminal_element_count;
my $bnok_sr
= $ln_fac_arr[$s]
- ( $ln_fac_arr[$r]
+ $ln_fac_arr[$s - $r]
);

# many var names from PhyloMeasures
my $bnok_ratio = $s - $se - $r + 1 > 0
? $ln_fac_arr[$s-$se]
- ( $ln_fac_arr[$r-1]
+ $ln_fac_arr[$s - $se - $r + 1]
)
- $bnok_sr
: -$bnok_sr;
my $mhyperg = exp $bnok_ratio;

#my ($sum_subtree, $sum_subtract);
#my $_sum_prod; # for debug

foreach my $child ($self->get_children) {
($sum_subtree, $sum_subtract)
= $child->_calc_nti_sd_subtree_bits(
%args,
sum_subtree => $sum_subtree,
sum_subtract => $sum_subtract,
);

my $sum_pr = $child->get_nti_sum_of_products_below;
$sum_subtree += $length * $sum_pr * $mhyperg;

#$_sum_prod += $sum_pr;

# slow - need to cache
my $descendants = $child->get_all_descendants;
foreach my $desc ($child, values %$descendants) {
my $d_len = $desc->get_length;
my $sl = $desc->get_terminal_element_count;

my $bnok_ratio
= $s - $sl - $se - $r + 2 > 0
? $ln_fac_arr[$s-$sl-$se]
- ( $ln_fac_arr[$r-2]
+ $ln_fac_arr[$s - $se - $sl - $r + 2]
)
- $bnok_sr
: $s - $sl - $se > 0
? -$bnok_sr
: -Inf;
$sum_subtract
+= $length * $sl
* $se * $d_len
* exp $bnok_ratio;

if (0 && $r == 2) {
say STDERR sprintf "SUBTRACT: len=%.6f, se=%d, chlen=%f, sl=%d, hypergeom=%.6f, sum_subtract=%.6f",
$length,
$se,
$d_len,
$sl,
exp ($bnok_ratio),
$sum_subtract;
}
}
}

$sum_subtree += ($length ** 2) * $se * $mhyperg;

$bnok_ratio
= $s - $se - $se - $r + 2 > 0
? $ln_fac_arr[$s-$se-$se]
- ( $ln_fac_arr[$r-2]
+ $ln_fac_arr[$s - $se - $se - $r + 2]
)
- $bnok_sr
: -Inf;

$sum_subtract += exp ($bnok_ratio) * ($length ** 2) * ($se ** 2);

if ($r == 2) {
say STDERR sprintf "SNUBTRACT: len=%.6f, se=%d, two_edge_pr=%.6f, sum_subtract=%.6f",
$length,
$se,
exp ($bnok_ratio),
$sum_subtract;
}

my @components = ($sum_subtree, $sum_subtract);

#if ($r == 2) {
# $_sum_prod += $length * $se;
# say STDERR sprintf "XX: %d, len=%.6f, se=%d, sum_subtree=%.6f, sum_subtract=%.6f, sum_prod=%.6f, mhyperg=%.6f",
# $r, $length, $se, $sum_subtree, $sum_subtract, $_sum_prod, $mhyperg;
#}

my $cache_name = 'NTI_SUBTREE_BITS_HASH';
$self->set_cached_value ($cache_name => \@components);

return wantarray ? @components : \@components;
}


sub numerically {$a <=> $b}

1;
Expand Down

0 comments on commit db491eb

Please sign in to comment.