# R/stats.R

# Class representing an abstract Stat, which is used to calculate values and
# confidence intervals of a matrix in a column-by-column manner.

Stat <- R6Class("Stat",
    public = list(
        # Validate the arguments, store them, compute the statistics with the
        # subclass-provided calculate_statistics() and validate the result.
        #
        # data:    matrix of values, processed column by column.
        # ctrl:    optional matrix with the same dimensions as data.
        # alpha:   single numeric in [0, 1].
        # average: either "mean" or "median".
        # range:   numeric of length 2; stored for subclasses (the subclasses
        #          in this file use it as the first/last value of the
        #          position column).
        initialize = function(data, ctrl = NULL, alpha = 0.05, average = "mean",
                              range = c(-1, 1)) {

            private$check_param(data = data, ctrl = ctrl, alpha = alpha,
                                average = average, range = range)

            # Save parameters
            private$parameters[["average"]] <- average
            private$parameters[["alpha"]] <- alpha
            private$parameters[["range"]] <- range

            private$data <- data
            if (!is.null(ctrl)) {
                private$ctrl <- ctrl
            }

            # Calculate the statistics, then make sure the subclass produced a
            # data.frame of the expected shape.
            private$statistics <- private$calculate_statistics()
            private$validate_statistics()
        },
        # Return the statistics data.frame (columns position, value, qinf and
        # qsup, as enforced by validate_statistics).
        get_statistics = function() {
            private$statistics
        }
    ),
    private = list(
        statistics = data.frame(position = numeric(),
                                value = numeric(),
                                qinf = numeric(),
                                qsup = numeric()),
        parameters = list(),
        data = matrix(),
        ctrl = NULL,
        # Abstract method: subclasses must override it and return the
        # statistics data.frame. Only reached when Stat is used directly.
        calculate_statistics = function() {
            stop("Stat is abstract: calculate_statistics must be ",
                 "implemented by a subclass")
        },
        # Stop with an informative error if any constructor argument is
        # invalid; returns nothing useful otherwise.
        check_param = function(data, ctrl, alpha, average, range) {
            # A matrix needs at least one non-NA value. all() of an empty
            # vector is TRUE, so empty matrices are rejected too.
            # (Note: `!is.na(x) == 0` would parse as `!(is.na(x) == 0)`.)
            if (!is.matrix(data) || all(is.na(data))) {
                stop("data must be a matrix with at least one value")
            }
            # NA values are rejected: the subclasses call mean()/quantile()
            # without na.rm.
            if (anyNA(data)) {
                stop("data must not contain NA values")
            }
            if (!is.null(ctrl)) {
                if (!is.matrix(ctrl) || all(is.na(ctrl))) {
                    stop("ctrl must be a matrix with at least one value")
                }
                if (anyNA(ctrl)) {
                    stop("ctrl must not contain NA values")
                }
                if (!identical(dim(data), dim(ctrl))) {
                    stop("data and ctrl must be of identical dimensions")
                }
            }
            # Length and NA are checked first so `||` only sees scalars.
            if (!is.numeric(alpha) || length(alpha) != 1 || is.na(alpha) ||
                alpha < 0 || alpha > 1) {
                stop("alpha parameter must be a numeric between 0 and 1")
            }
            if (length(average) != 1 || !average %in% c("mean", "median")) {
                stop("average parameter must be either 'mean' or 'median'")
            }
            if (!is.numeric(range) || length(range) != 2) {
                stop("range parameter must be a numeric of length 2")
            }
        },
        # Stop if private$statistics is not a data.frame with exactly the
        # columns position, value, qinf, qsup (all numeric when non-empty).
        validate_statistics = function() {
            error <- "Stat object is not in a valid state:"

            # Check class
            if (!inherits(private$statistics, "data.frame")) {
                reason <- "statistics must be a data.frame."
                stop(paste(error, reason))
            }

            # Check columns
            expected <- c("position", "value", "qinf", "qsup")
            if (!identical(colnames(private$statistics), expected)) {
                reason <- "invalid column names."
                stop(paste(error, reason))
            }
            if (nrow(private$statistics) > 0) {
                is_num <- vapply(private$statistics, is.numeric, logical(1))
                if (!all(is_num)) {
                    reason <- "invalid column classes."
                    stop(paste(error, reason))
                }
            }
        }
    )
)

# Class representing the basic Stat. For each column, the value is either the
# mean or the median (selected by the user), and qinf and qsup are the
# alpha/2 and 1 - alpha/2 quantiles of the column's values.

Basic_Stat <- R6Class("Basic_Stat",
    inherit = Stat,
    public = list(
    ),
    private = list(
        # For each column of data, compute the average (mean or median, per
        # the "average" parameter) and the alpha/2 and 1 - alpha/2 quantiles
        # of its values.
        #
        # Returns a data.frame with one row per column of data and columns
        # position, value, qinf and qsup; position runs evenly from range[1]
        # to range[2]. A 0-column matrix yields a 0-row data.frame with the
        # same columns.
        calculate_statistics = function() {
            # Fetch relevant params
            alpha <- private$parameters[["alpha"]]
            data <- private$data
            range <- private$parameters[["range"]]
            # Choose the averaging function once instead of per column.
            average_fun <- if (private$parameters[["average"]] == "mean") {
                mean
            } else {
                median
            }

            # Summarise one column into value, qinf and qsup.
            summarise_column <- function(column_values) {
                c(value = average_fun(column_values),
                  qinf = unname(quantile(column_values, alpha / 2)),
                  qsup = unname(quantile(column_values, 1 - alpha / 2)))
            }

            position <- seq(range[1], range[2], length.out = ncol(data))
            # split() by col() gives one vector per column and, unlike
            # rep(1:ncol(data), ...), stays correct when ncol(data) is 0.
            columns <- split(data, col(data))
            # 3 x ncol matrix whose rows are named value, qinf, qsup.
            stats_matrix <- vapply(columns, summarise_column,
                                   c(value = 0, qinf = 0, qsup = 0),
                                   USE.NAMES = FALSE)
            data.frame(position = position, t(stats_matrix), row.names = NULL)
        }
    )
)

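# Class representing the bootstrap Stat. Each column is resampled with
# replacement sample_count times; value is the median of the replicate
# averages, and qinf and qsup are their alpha/2 and 1 - alpha/2 quantiles.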

Bootstrap_Stat <- R6Class("Bootstrap_Stat",
    inherit = Stat,
    public = list(
        # Validate and store the bootstrap-specific arguments, then delegate
        # to Stat$initialize, which stores the shared arguments and computes
        # the statistics.
        #
        # sample_count: positive integer; number of replicates per column.
        # sample_size:  positive integer, or NA (default) to use nrow(data);
        #               number of values drawn per replicate.
        # debug:        TRUE or FALSE; when TRUE, the replicates and their
        #               averages are kept and returned by get_statistics().
        initialize = function(data, ctrl = NULL, alpha = 0.05, average = "mean",
                              range = c(-1, 1), sample_count = 1000,
                              sample_size = NA, debug = FALSE) {

            # TRUE when x is a single finite whole number in
            # [1, .Machine$integer.max]. Length and finiteness are checked
            # first so `&&` only ever sees scalar, non-NA operands.
            is_count <- function(x) {
                is.numeric(x) && length(x) == 1 && is.finite(x) &&
                    x >= 1 && x <= .Machine$integer.max && x == round(x)
            }

            # Check parameters validity
            super$check_param(data = data, ctrl = ctrl, alpha = alpha,
                              average = average, range = range)
            if (!is_count(sample_count)) {
                stop("sample_count must be a positive integer.")
            }
            use_default_size <- length(sample_size) == 1 && is.na(sample_size)
            if (!use_default_size && !is_count(sample_size)) {
                stop("sample_size must be a positive integer.")
            }
            if (!isTRUE(debug) && !isFALSE(debug)) {
                stop("debug must be TRUE or FALSE.")
            }

            # Save parameters
            private$parameters[["sample_count"]] <- sample_count
            if (use_default_size) {
                sample_size <- nrow(data)
            }
            private$parameters[["sample_size"]] <- sample_size
            private$parameters[["debug"]] <- debug

            # Initialize and calculate statistic
            super$initialize(data = data, ctrl = ctrl, alpha = alpha,
                             average = average, range = range)
        },
        # Return the statistics data.frame; when debug is TRUE, return a list
        # with the statistics plus the per-column values and replicates.
        get_statistics = function() {
            if (private$parameters[["debug"]]) {
                list(statistics = private$statistics,
                     values = private$values,
                     replicates = private$replicates)
            } else {
                private$statistics
            }
        }
    ),
    private = list(
        # Filled only when debug is TRUE: one entry per processed column, in
        # column order.
        replicates = list(),
        values = list(),
        # Bootstrap every column and return a data.frame with columns
        # position, value, qinf and qsup (one row per column of data).
        calculate_statistics = function() {
            # Fetch relevant params
            data <- private$data
            ctrl <- private$ctrl
            range <- private$parameters[["range"]]

            position <- seq(range[1], range[2], length.out = ncol(data))
            # split() by col() gives one vector per column; safe for ncol 0.
            columns <- split(data, col(data))
            # Pair each data column with its matching ctrl column (check_param
            # guarantees identical dimensions). When ctrl is NULL,
            # ctrl_columns[[j]] is NULL too.
            ctrl_columns <- if (is.null(ctrl)) NULL else split(ctrl, col(ctrl))

            # 3 x ncol matrix whose rows are named value, qinf, qsup.
            stats_matrix <- vapply(
                seq_along(columns),
                function(j) {
                    private$calculate_statistic(columns[[j]], ctrl_columns[[j]])
                },
                c(value = 0, qinf = 0, qsup = 0)
            )
            data.frame(position = position, t(stats_matrix), row.names = NULL)
        },
        # Bootstrap a single column: draw the replicates, average each one,
        # and return c(value, qinf, qsup) = the 0.5, alpha/2 and
        # 1 - alpha/2 quantiles of those averages.
        calculate_statistic = function(column_values, ctrl = NULL) {
            # Check param
            stopifnot(is.numeric(column_values))
            stopifnot(length(column_values) > 0)
            if (!is.null(ctrl)) {
                stopifnot(is.numeric(ctrl))
                stopifnot(length(ctrl) > 0)
            }

            alpha <- private$parameters[["alpha"]]

            # NOTE(review): ctrl is forwarded, but generate_draw_values does
            # not currently use it, so ctrl has no effect on the result.
            replicates <- private$generate_draw_values(column_values, ctrl)
            values <- private$calculate_replicate_values(replicates)
            if (private$parameters[["debug"]]) {
                i <- length(private$replicates) + 1
                private$replicates[[i]] <- replicates
                private$values[[i]] <- values
            }

            res <- quantile(values, c(0.5, alpha / 2, 1 - alpha / 2))
            names(res) <- c("value", "qinf", "qsup")
            res
        },
        # Return a sample_size x sample_count matrix whose columns are
        # replicates drawn with replacement from column_values.
        generate_draw_values = function(column_values, ctrl = NULL) {
            sample_count <- private$parameters[["sample_count"]]
            sample_size <- private$parameters[["sample_size"]]

            # Sample indices rather than values: sample(x, ...) on a length-1
            # numeric x would draw from 1:x instead of from x itself.
            draws <- column_values[sample(seq_along(column_values),
                                          sample_size * sample_count,
                                          replace = TRUE)]
            matrix(draws, ncol = sample_count)
        },
        # Average each replicate (column) with colMeans or
        # matrixStats::colMedians, according to the "average" parameter.
        calculate_replicate_values = function(replicates) {
            average <- private$parameters[["average"]]
            if (identical(average, "mean")) {
                colMeans(replicates)
            } else {
                matrixStats::colMedians(replicates)
            }
        }
    )
)
# CharlesJB/metagene documentation built on July 11, 2021, 11:48 a.m.