Skip to content

Commit

Permalink
Simplex uses Stan's stick breaking transform and fix Dirichlet deriva…
Browse files Browse the repository at this point in the history
…tive
  • Loading branch information
4ment committed Mar 15, 2024
1 parent dade424 commit 9cd695e
Show file tree
Hide file tree
Showing 14 changed files with 318 additions and 247 deletions.
14 changes: 12 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ SET( LIBRARY_NAME "PHYC" )

SET( ${LIBRARY_NAME}_MAJOR_VERSION 2 )
SET( ${LIBRARY_NAME}_MINOR_VERSION 0 )
SET( ${LIBRARY_NAME}_PATCH_LEVEL 1-dev )
SET( ${LIBRARY_NAME}_PATCH_LEVEL 0 )


#####################################################################
Expand Down Expand Up @@ -183,12 +183,21 @@ ENDIF()
# Config header file
#####################################################################

execute_process(
COMMAND git tag --points-at HEAD
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE GIT_HEAD_TAG
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_QUIET
)

# Get the current working branch
execute_process(
COMMAND git rev-parse --abbrev-ref HEAD
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE GIT_BRANCH
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_QUIET
)

# Get the latest abbreviated commit hash of the working branch
Expand All @@ -197,6 +206,7 @@ execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
OUTPUT_VARIABLE GIT_COMMIT_HASH
OUTPUT_STRIP_TRAILING_WHITESPACE
ERROR_QUIET
)

configure_file (
Expand Down Expand Up @@ -338,7 +348,7 @@ IF(NOT DISABLE_GSL)
target_link_libraries(phyc GSL::gsl GSL::gslcblas)
ENDIF()

IF(NOT LIBRARY_ONLY)
IF(NOT LIBRARY_ONLY AND NOT DISABLE_GSL)
add_executable(physher src/physher.c)
target_link_libraries(physher phyc)

Expand Down
140 changes: 0 additions & 140 deletions examples/fluA/checkpoint.jon

This file was deleted.

2 changes: 1 addition & 1 deletion src/modelAveraging.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ int main(int argc, char* argv[]){

if ( args_contains(argc, argv, "-c") ) {
bool success = false;
nclasses = args_get_int(argc, argv, "-c", &success );
nclasses = args_get_int2(argc, argv, "-c", &success );
if( !success || nclasses <= 0 ){
error("Could not read the number of classes [-c]");
}
Expand Down
1 change: 1 addition & 0 deletions src/phyc/PhyCConfig.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@

#define GIT_BRANCH "@GIT_BRANCH@"
#define GIT_COMMIT_HASH "@GIT_COMMIT_HASH@"
#define GIT_HEAD_TAG "@GIT_HEAD_TAG@"

#endif
50 changes: 49 additions & 1 deletion src/phyc/args.c
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ int * args_get_pint( int argc, char* argv[], const char flag[] ){
return option;
}

int args_get_int( int argc, char* argv[], const char flag[], bool *success ){
int args_get_int2( int argc, char* argv[], const char flag[], bool *success ){
int i = 0;
char *str = NULL;

Expand All @@ -446,6 +446,54 @@ int args_get_int( int argc, char* argv[], const char flag[], bool *success ){
return 0;
}

int args_get_int( int argc, char* argv[], const char flag[], int defaultv ){
int i = 0;
char *str = NULL;

for ( ; i < argc; i++) {
if ( strncmp(argv[i], flag, strlen(flag)) == 0 ) {
str = NULL;
if( strlen(argv[i]) > strlen(flag) ){
str = argv[i]+strlen(flag);
}
else if( i+1 < argc ){
str = argv[i+1];
}

if( str != NULL && isInt(str) ){
return atoi( str );
}
break;
}

}
return defaultv;
}

long args_get_long( int argc, char* argv[], const char flag[], long defaultv ){
int i = 0;
char *str = NULL;

for ( ; i < argc; i++) {
if ( strncmp(argv[i], flag, strlen(flag)) == 0 ) {
str = NULL;
if( strlen(argv[i]) > strlen(flag) ){
str = argv[i]+strlen(flag);
}
else if( i+1 < argc ){
str = argv[i+1];
}

if( str != NULL && isInt(str) ){
return atoi( str );
}
break;
}

}
return defaultv;
}

double * args_get_pdouble( int argc, char* argv[], const char flag[] ){
double *option = NULL;
char *str = NULL;
Expand Down
6 changes: 5 additions & 1 deletion src/phyc/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ char * args_get_string( int argc, char* argv[], const char flag[] );

int * args_get_pint( int argc, char* argv[], const char flag[] );

int args_get_int( int argc, char* argv[], const char flag[], bool *success );
int args_get_int2( int argc, char* argv[], const char flag[], bool *success );

int args_get_int( int argc, char* argv[], const char flag[], int defaultv );

long args_get_long( int argc, char* argv[], const char flag[], long defaultv );

double * args_get_pdouble( int argc, char* argv[], const char flag[] );

Expand Down
30 changes: 28 additions & 2 deletions src/phyc/distdirichlet.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,18 @@ double DistributionModel_log_flat_dirichlet_with_values(DistributionModel* dm, c
}

double DistributionModel_dlog_flat_dirichlet(DistributionModel* dm, const Parameter* p){
return 0.0;
for (int i = 0; i < Parameters_count(dm->x); i++) {
if( p == Parameters_at(dm->x, i)){
const double* values = dm->simplex->get_values(dm->simplex);
double dlogp = 0;
for (size_t j = i; j < Parameters_count(dm->x); j++) {
dm->simplex->gradient(dm->simplex, i, dm->tempp);
dlogp += dm->tempp[j]/values[j];
}
return dlogp;
}
}
return 0.0;
}

double DistributionModel_d2log_flat_dirichlet(DistributionModel* dm, const Parameter* p){
Expand Down Expand Up @@ -75,10 +86,25 @@ static double DistributionModel_dirichlet_sample_evaluate(DistributionModel* dm)
return logP;
}

// IMPORTANT: The derivative is wrt unconstrained parameter of the simplex
double DistributionModel_dlog_dirichlet(DistributionModel* dm, const Parameter* p){
/*
log pdf(X; \alpha) = \sum_i \alpha_i log(x_i) - log B(\alpha)
d log pdf(X)/dz_k = \sum_i \alpha_i d log(x_i)/dz_k
&= \sum_i \alpha_i d log(x_i)/dx_i dx_i/dz_k
&= \sum_i \alpha_i/x_i dx_i/dz_k
dx_i/dz_k = 0 for i < k
*/
for (int i = 0; i < Parameters_count(dm->x); i++) {
if( p == Parameters_at(dm->x, i)){
return (Parameters_value(dm->parameters[0], i)-1.0)/Parameter_value(p);
const double* values = dm->simplex->get_values(dm->simplex);
double dlogp = 0;
for (size_t j = i; j < Parameters_count(dm->x); j++) {
dm->simplex->gradient(dm->simplex, i, dm->tempp);
dlogp += (Parameters_value(dm->parameters[0], j)-1.0)/values[j] * dm->tempp[j];
}
return dlogp;
}
}
return 0;
Expand Down
Loading

0 comments on commit 9cd695e

Please sign in to comment.