update vignettes and website
MaximilianPi committed Mar 14, 2024
1 parent fb84552 commit c8637cc
Showing 102 changed files with 2,456 additions and 2,178 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -5,6 +5,7 @@
* burnin parameter
* multivariate probit model
* X and Y support (alternative interface)
* negative binomial distribution

## Minor changes
* Improved vignette
30 changes: 22 additions & 8 deletions R/dnn.R
@@ -34,6 +34,10 @@
#'
#' @details
#'
#' # Activation functions
#'
#' Supported activation functions: "relu", "leaky_relu", "tanh", "elu", "rrelu", "prelu", "softplus", "celu", "selu", "gelu", "relu6", "sigmoid", "softsign", "hardtanh", "tanhshrink", "softshrink", "hardshrink", "log_sigmoid"
#'
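As an illustration (not part of the package documentation), a minimal call that selects one of the listed activation functions; the data set and layer sizes are arbitrary:

```r
library(cito)

# a two-layer network with ReLU activations instead of the default "selu"
fit_relu <- dnn(Sepal.Length ~ ., data = datasets::iris,
                hidden = c(30L, 30L), activation = "relu")
```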
#' # Loss functions / Likelihoods
#'
#' We support loss functions and likelihoods for different tasks:
@@ -47,6 +51,7 @@
#' | gaussian | Normal likelihood | Regression, residual error is also estimated (similar to `stats::lm()`) |
#' | binomial | Binomial likelihood | Classification/logistic regression, mortality |
#' | poisson | Poisson likelihood | Regression, count data, e.g. species abundances |
#' | nbinom | Negative binomial likelihood | Regression, count data with dispersion parameter |
#' | mvp | Multivariate probit model | Joint species distribution model, multi-species (presence-absence) |
#'
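A hedged sketch of how the new negative binomial likelihood could be requested; the simulated data, coefficients, and layer sizes below are invented for illustration:

```r
library(cito)

# simulated overdispersed count data (illustrative only)
set.seed(42)
sim <- data.frame(x = runif(200, -1, 1))
sim$counts <- rnbinom(200, mu = exp(1 + 2 * sim$x), size = 0.7)

# fit a small network with the negative binomial likelihood
fit_nb <- dnn(counts ~ x, data = sim, hidden = c(20L, 20L),
              activation = "selu", loss = "nbinom")
```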
#' # Training and convergence of neural networks
@@ -162,12 +167,10 @@
dnn <- function(formula = NULL,
data = NULL,
hidden = c(50L, 50L),
activation = c("relu", "leaky_relu", "tanh", "elu", "rrelu", "prelu", "softplus",
"celu", "selu", "gelu", "relu6", "sigmoid", "softsign", "hardtanh",
"tanhshrink", "softshrink", "hardshrink", "log_sigmoid"),
activation = "selu",
bias = TRUE,
dropout = 0.0,
loss = c("mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial", "poisson", "mvp"),
loss = c("mse", "mae", "softmax", "cross-entropy", "gaussian", "binomial", "poisson", "mvp", "nbinom"),
validation = 0,
lambda = 0.0,
alpha = 0.5,
@@ -190,11 +193,9 @@ dnn <- function(formula = NULL,
X = NULL,
Y = NULL) {

if(!inherits(activation, "tune")) {

}

out <- list()

class(out) <- "citodnn"

tuner = check_hyperparameters(hidden = hidden ,
@@ -212,7 +213,20 @@

if(!is.function(loss) & !inherits(loss,"family")){
loss <- match.arg(loss)

if((device == "mps") & (loss %in% c("poisson", "nbinom"))) {
message("`poisson` or `nbinom` are not yet supported for `device=mps`, switching to `device=cpu`")
device = "cpu"
}
}

if(inherits(loss,"family")) {
if((device == "mps") & (loss$family %in% c("poisson", "nbinom"))) {
message("`poisson` or `nbinom` are not yet supported for `device=mps`, switching to `device=cpu`")
device = "cpu"
}
}

device_old = device
device = check_device(device)
tmp_data = get_X_Y(formula, X, Y, data)
@@ -268,7 +282,7 @@


loss_obj <- get_loss(loss, device = device, X = X, Y = Y)
if(!is.null(loss_obj$parameter)) loss_obj$parameter <- list(paramter = loss_obj$parameter)
if(!is.null(loss_obj$parameter)) loss_obj$parameter <- list(parameter = loss_obj$parameter)
if(!is.null(custom_parameters)){
if(!inherits(custom_parameters,"list")){
warning("custom_parameters has to be list")
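A hypothetical call illustrating the MPS fallback added above; the data are made up and the expected message mirrors the `message()` branch in the diff:

```r
library(cito)

# illustrative overdispersed counts
set.seed(7)
dat <- data.frame(x = rnorm(100))
dat$counts <- rnbinom(100, mu = exp(0.5 + dat$x), size = 1.2)

# requesting the Apple MPS backend with a count likelihood falls back to the CPU
fit <- dnn(counts ~ x, data = dat, loss = "nbinom", device = "mps")
#> `poisson` and `nbinom` are not yet supported for `device=mps`, switching to `device=cpu`
```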
40 changes: 40 additions & 0 deletions R/utils.R
@@ -153,6 +153,46 @@ get_loss <- function(loss, device = "cpu", X = NULL, Y = NULL) {
Eprob = (exp(logprob-maxlogprob))$mean(dim = 1)
return((-log(Eprob) - maxlogprob)$mean())
}
} else if(loss == "nbinom") {

# one trainable dispersion parameter per response, optimized on an unconstrained scale
if(is.matrix(Y)) out$parameter = torch::torch_tensor(rep(0.5, ncol(Y)), requires_grad=TRUE, device = device)
else out$parameter = torch::torch_tensor(0.5, requires_grad=TRUE, device = device)
out$invlink <- function(a) torch::torch_exp(a)
out$link <- function(a) log(as.matrix(a))
# the dispersion theta is recovered as 1/(softplus(parameter) + eps)
out$parameter_link = function() as.numeric(1.0/(torch::nnf_softplus(out$parameter)+0.0001))
out$simulate = function(pred) {
theta_tmp = out$parameter_link()
probs = 1.0 - theta_tmp/(theta_tmp + pred)
total_count = theta_tmp

if(is.matrix(pred)) {
sim = sapply(1:ncol(pred), function(i) {
logits = log(probs[,i]) - log1p(-probs[,i])
# gamma-Poisson mixture: Poisson rates drawn from Gamma(shape = theta, rate = exp(-logits)) yield a negative binomial
return( stats::rpois(length(logits), stats::rgamma(length(logits), total_count[i], exp(-logits))) )
})
} else {
logits = log(probs) - log1p(-probs)
# same gamma-Poisson mixture for a single response
sim = stats::rpois(length(pred), stats::rgamma(length(pred), total_count, exp(-logits)))
}
return(sim)
}

out$loss = function(pred, true) {
# negative binomial negative log-likelihood in the (total_count, logits) parameterization
eps = 0.0001
pred = pred$exp()
theta_tmp = 1.0/(torch::nnf_softplus(out$parameter)+eps)
probs = torch::torch_clamp(1.0 - theta_tmp/(theta_tmp+pred)+eps, 0.0, 1.0-eps)
total_count = theta_tmp
value = true
logits = torch::torch_log(probs) - torch::torch_log1p(-probs)
log_unnormalized_prob <- total_count * torch::torch_log(torch::torch_sigmoid(-logits)) + value * torch::torch_log(torch::torch_sigmoid(logits))
log_normalization <- -torch::torch_lgamma(total_count + value) + torch::torch_lgamma(1 + value) + torch::torch_lgamma(total_count)
log_normalization <- torch::torch_where(total_count + value == 0, torch::torch_tensor(0, dtype = log_normalization$dtype, device = out$parameter$device), log_normalization)
return( - (log_unnormalized_prob - log_normalization)$mean())
}

}
else{
cat( "unidentified loss \n")
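For intuition, a standalone check (not part of the package) that the gamma-Poisson mixture used in `out$simulate()` reproduces a negative binomial with mean `mu` and dispersion `theta`, where `mu` and `theta` stand in for `pred` and the transformed dispersion `theta_tmp`:

```r
set.seed(1)
n     <- 1e5
mu    <- 3.0   # stands in for `pred`
theta <- 0.7   # stands in for `theta_tmp` (total_count)

# probs and logits exactly as in the simulate function above
probs  <- 1 - theta / (theta + mu)
logits <- log(probs) - log1p(-probs)   # = log(mu / theta)

sim_mix <- rpois(n, rgamma(n, shape = theta, rate = exp(-logits)))
sim_ref <- rnbinom(n, size = theta, mu = mu)

c(mean(sim_mix), mean(sim_ref))  # both close to mu = 3
c(var(sim_mix),  var(sim_ref))   # both close to mu + mu^2/theta, about 15.9
```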