R/CA.R

Defines functions elbow_method scree_plot ca_coords subset_dims run_cacomp inertia_rows var_rows rm_zeros calc_residuals comp_ft_residuals clip_residuals comp_NB_residuals comp_std_residuals

Documented in ca_coords calc_residuals clip_residuals comp_ft_residuals comp_NB_residuals comp_std_residuals elbow_method inertia_rows rm_zeros run_cacomp scree_plot subset_dims var_rows

#' @include constructor.R
NULL

#' Compute Standardized Residuals
#'
#' @description
#' `comp_std_residuals` computes the standardized residual matrix S based on
#' the Poisson model,
#' which is the basis for correspondence analysis and serves
#' as input for singular value decomposition (SVD).
#'
#' @details
#' Calculates standardized residual matrix S from the proportion matrix P and
#' the expected values E according to \eqn{S = \frac{(P-E)}{\sqrt{E}}}.
#'
#' @param mat A numerical matrix or coercible to one by `as.matrix()`.
#' Should have row and column names.
#' @param clip logical. Whether residuals should be clipped if they are
#' higher/lower than a specified cutoff
#' @param cutoff numeric. Residuals that are larger than cutoff or lower than
#' -cutoff are clipped to cutoff.
#'
#' @inherit calc_residuals return
#'
comp_std_residuals <- function(mat, clip = FALSE, cutoff = NULL) {
    # Accept base matrices as well as dense/sparse Matrix classes.
    stopifnot(is(mat, "matrix") | is(mat, "dgeMatrix") | is(mat, "dgCMatrix"))

    # Both sets of dimnames are required.
    stopifnot(
        "Input matrix does not have any rownames!" = !is.null(rownames(mat)),
        "Input matrix does not have any colnames!" = !is.null(colnames(mat))
    )

    grand_total <- sum(mat)
    prop <- mat / grand_total # proportions matrix P

    row_mass <- Matrix::rowSums(prop)
    col_mass <- Matrix::colSums(prop)

    # Expected proportions under independence: outer product of the masses.
    expected <- row_mass %o% col_mass

    # Standardized residuals; entries that come out NA/NaN (e.g. 0/0 when an
    # expected proportion is zero) are set to 0.
    resid <- (prop - expected) / sqrt(expected)
    resid[is.na(resid)] <- 0

    if (isTRUE(clip)) {
        # Default cutoff when none is supplied.
        if (is.null(cutoff)) cutoff <- sqrt(ncol(resid) / grand_total)

        resid <- clip_residuals(resid, cutoff = cutoff)
    }

    list(
        "S" = resid,
        "tot" = grand_total,
        "rowm" = row_mass,
        "colm" = col_mass
    )
}


#' Compute Negative-Binomial residuals
#' @description
#' Computes the residuals based on the negative binomial model. By default a
#' theta of 100 is used to capture technical variation.
#'
#' @inheritParams comp_std_residuals
#' @param freq logical. Whether a table of frequencies (as used in CA) should
#' be used.
#' @param theta Overdispersion parameter. By default set to 100 as described in
#' Lause and Berens, 2021 (see references).
#'
#' @references
#' Lause, J., Berens, P. & Kobak, D. Analytic Pearson residuals for
#' normalization of single-cell RNA-seq UMI data. Genome Biol 22, 258 (2021).
#' https://doi.org/10.1186/s13059-021-02451-7
#'
#' @inherit calc_residuals return
comp_NB_residuals <- function(mat,
                              theta = 100,
                              clip = FALSE,
                              cutoff = NULL,
                              freq = TRUE) {
    # Optionally work on the table of frequencies instead of raw values.
    if (isTRUE(freq)) {
        mat <- mat / sum(mat)
    }

    tot <- sum(mat)
    row_sums <- Matrix::rowSums(mat)
    col_sums <- Matrix::colSums(mat)

    # Expected values from the margins.
    expected <- (row_sums %o% col_sums) / tot

    # Variance term used here: mu + mu^2 / theta.
    nb_var <- expected + (expected^2) / theta
    resid <- (mat - expected) / sqrt(nb_var)

    # Relation to the Pearson-based S (for theta = Inf): sqrt(tot) / tot * Z

    if (isTRUE(clip)) {
        # Default cutoff: square root of the number of columns.
        if (is.null(cutoff)) cutoff <- sqrt(ncol(resid))

        resid <- clip_residuals(resid, cutoff = cutoff)
    }

    list("S" = resid, "tot" = tot, "rowm" = row_sums, "colm" = col_sums)
}


#' Perform clipping of residuals
#'
#' @description
#' Clips Pearson or negative-binomial residuals that lie above `cutoff` or
#' below `-cutoff`. For Pearson (Poisson) residuals the default cutoff is 1,
#' for NB residuals it is sqrt(n), where n is the number of columns.
#'
#' @param S Matrix of residuals.
#' @param cutoff Value above/below which clipping should happen.
#'
#' @references
#' Lause, J., Berens, P. & Kobak, D. Analytic Pearson residuals for
#' normalization of single-cell RNA-seq UMI data. Genome Biol 22, 258 (2021).
#' https://doi.org/10.1186/s13059-021-02451-7
#'
#' @returns
#' Matrix of clipped residuals.
clip_residuals <- function(S, cutoff = sqrt(ncol(S))) {
    # Cap from above first, then from below (the order matters only for a
    # negative cutoff and is kept as is).
    upper <- cutoff
    S[S > upper] <- upper

    lower <- -cutoff
    S[S < lower] <- lower

    S
}



#' Compute Freeman-Tukey residuals
#'
#' @description
#' Computes Freeman-Tukey residuals
#'
#' @inheritParams comp_std_residuals
#'
#' @inherit calc_residuals return
comp_ft_residuals <- function(mat) {
    stopifnot(is(mat, "matrix") | is(mat, "dgeMatrix") | is(mat, "dgCMatrix"))

    grand_total <- sum(mat)
    prop <- mat / grand_total

    row_mass <- Matrix::rowSums(prop)
    col_mass <- Matrix::colSums(prop)

    # Expected proportions: outer product of row and column masses.
    expected <- row_mass %*% t(col_mass)

    # Freeman-Tukey form on the proportion scale:
    # sqrt(p) + sqrt(p + 1/N) - sqrt(4 * e + 1/N)
    inv_total <- 1 / grand_total
    observed_part <- prop^.5 + (prop + inv_total)^.5
    expected_part <- (4 * expected + inv_total)^.5
    resid <- observed_part - expected_part

    list(
        "S" = resid,
        "tot" = grand_total,
        "rowm" = row_mass,
        "colm" = col_mass
    )
}


#' Calculate residuals for Correspondence analysis
#'
#' @description
#' `calc_residuals` provides optional residuals as the basis for Correspondence
#' Analysis
#'
#' @param residuals character string. Specifies which kind of residuals should
#' be calculated. Can be "pearson" (default), "freemantukey" or "NB" for
#' negative-binomial.
#' @inheritParams comp_std_residuals
#'
#' @returns
#' A named list. The elements are:
#' * "S": standardized residual matrix.
#' * "tot": grand total of the original matrix.
#' * "rowm": row masses.
#' * "colm": column masses.
#'
#' @md
calc_residuals <- function(mat,
                           residuals = "pearson",
                           clip = FALSE,
                           cutoff = NULL) {
    # Dispatches to the residual function selected by `residuals`.
    # `clip` is forwarded unchanged, so clipping only happens when the caller
    # asks for it (previously it was always switched on).
    if (residuals == "pearson") {
        # Default clipping cutoff for Pearson residuals.
        if (is.null(cutoff)) cutoff <- 1

        res <- comp_std_residuals(
            mat = mat,
            clip = clip,
            cutoff = cutoff
        )
    } else if (residuals == "freemantukey") {
        # No clipping exists for Freeman-Tukey residuals; tell the caller
        # when a clipping-related argument is being dropped.
        if (isTRUE(clip) || !is.null(cutoff)) {
            warning("Clipping for freemantukey residuals is not implemented. Argument ignored.")
        }
        res <- comp_ft_residuals(mat)
    } else if (residuals == "NB") {
        # With cutoff = NULL, comp_NB_residuals picks its own default.
        res <- comp_NB_residuals(
            mat = mat,
            theta = 100,
            cutoff = cutoff,
            clip = clip
        )
    } else {
        stop("Unknown type of residuals.")
    }

    return(res)
}

#' removes 0-only rows and columns in a matrix.
#'
#' @param obj A matrix.
#' @return Input matrix with rows & columns consisting of only 0 removed.
rm_zeros <- function(obj) {
    # Drops rows and columns whose sum is not > 0 and warns when it does so.
    # Always returns a two-dimensional object, even if a single row or column
    # is left.
    stopifnot(is(obj, "matrix") | is(obj, "dgeMatrix") | is(obj, "dgCMatrix"))

    # Both masks are computed on the input, before anything is removed.
    no_zeros_rows <- Matrix::rowSums(obj) > 0
    no_zeros_cols <- Matrix::colSums(obj) > 0

    if (sum(!no_zeros_rows) != 0) {
        ## Delete genes with only zero values across all columns
        warning(
            "Matrix contains rows with only 0s. ",
            "These rows were removed. ",
            "If undesired set rm_zeros = FALSE."
        )
        # drop = FALSE: keep a matrix even when one row remains.
        obj <- obj[no_zeros_rows, , drop = FALSE]
    }
    if (sum(!no_zeros_cols) != 0) {
        ## Delete cells with only zero values across all genes
        warning(
            "Matrix contains columns with only 0s. ",
            "These columns were removed. ",
            "If undesired set rm_zeros = FALSE."
        )
        # drop = FALSE: keep a matrix even when one column remains.
        obj <- obj[, no_zeros_cols, drop = FALSE]
    }

    return(obj)
}

#' Find most variable rows
#'
#' @description
#' Calculates the variance of the chi-square component matrix and selects the
#' rows with the highest variance, e.g. 5,000.
#'
#' @return
#' Returns a matrix, which consists of the top variable rows of mat.
#'
#' @param mat A numeric matrix. For sequencing a count matrix,
#' gene expression values with genes in rows and samples/cells in columns.
#' Should contain row and column names.
#' @param top Integer. Number of most variable rows to retain. Default 5000.
#' @inheritParams calc_residuals
#' @param ... Further arguments for `calc_residuals`.
#' @export
#' @examples
#' set.seed(1234)
#'
#' # Simulate counts
#' cnts <- mapply(function(x){rpois(n = 500, lambda = x)},
#'                x = sample(1:20, 50, replace = TRUE))
#' rownames(cnts) <- paste0("gene_", 1:nrow(cnts))
#' colnames(cnts) <- paste0("cell_", 1:ncol(cnts))
#'
#' # Choose top 5000 most variable genes
#' cnts <- var_rows(mat = cnts, top = 5000)
#'
#'
var_rows <- function(mat,
                     residuals = "pearson",
                     top = 5000,
                     ...) {
    # Ranks rows by the variance of their residuals and returns the `top`
    # highest-ranked rows of `mat` (always as a matrix).
    res <- calc_residuals(
        mat = mat,
        residuals = residuals,
        ...
    )

    if (top > nrow(mat)) {
        warning(
            "Top is larger than the number of rows in matrix. ",
            "Top was set to nrow(mat)."
        )
    }

    top <- min(nrow(mat), top)
    S <- res$S
    if (residuals == "pearson") {
        S <- res$tot * (S^2) # chi-square components matrix
    }

    variances <- apply(S, 1, var) # row-wise variances
    ix_var <- order(-variances)

    # drop = FALSE keeps the matrix shape (and dimnames) when top == 1.
    mat <- mat[ix_var[seq_len(top)], , drop = FALSE] # choose top rows
    return(mat)
}

#' Find most variable rows
#'
#' @description
#' Calculates the inertia contributed by each row, defined as the sum of
#' squares of its Pearson residuals, and selects the rows with the largest
#' inertias, e.g. 5,000.
#'
#' @param mat A matrix with genes in rows and cells in columns.
#' @param top Number of genes to select.
#' @param ... Further arguments for `comp_std_residuals`
#'
#' @return
#' Returns a matrix, which consists of the top variable rows of mat.
inertia_rows <- function(mat, top = 5000, ...) {
    # Ranks rows by their inertia (row sums of squared standardized
    # residuals) and returns the `top` highest-ranked rows of `mat`
    # (always as a matrix).
    res <- comp_std_residuals(
        mat = mat,
        ...
    )

    if (top > nrow(mat)) {
        warning(
            "Top is larger than the number of rows in matrix. ",
            "Top was set to nrow(mat)."
        )
    }

    top <- min(nrow(mat), top)

    inertia <- res$S^2
    inertia <- Matrix::rowSums(inertia)
    ix <- order(inertia, decreasing = TRUE)

    # drop = FALSE keeps the matrix shape (and dimnames) when top == 1.
    mat <- mat[ix[seq_len(top)], , drop = FALSE] # choose top rows

    return(mat)
}


#' Internal function for `cacomp`
#'
#' @description
#' `run_cacomp` performs correspondence analysis on a matrix and returns the
#' transformed data.
#'
#' @details
#' The calculation is performed according to the work of Michael Greenacre.
#' When working with large matrices,
#' CA coordinates and
#' principal coordinates should only be computed when needed to save
#' computational time.
#'
#' @return
#' Returns a named list of class "cacomp" with components
#' U, V and D: The results from the SVD.
#' row_masses and col_masses: Row and columns masses.
#' top_rows: How many of the most variable rows/genes were retained for the
#' analysis.
#' tot_inertia, row_inertia and col_inertia: Only if inertia = TRUE. Total,
#' row and column inertia respectively.
#' @references
#' Greenacre, M. Correspondence Analysis in Practice, Third Edition, 2017.
#'
#' @param obj A numeric matrix or Seurat/SingleCellExperiment object. For
#' sequencing a count matrix, gene expression values with genes in rows and
#' samples/cells in columns.
#' Should contain row and column names.
#' @param coords Logical. Indicates whether CA standard coordinates should be
#' calculated.
#' @param python DEPRECATED. A logical value indicating whether to use singular-value
#' decomposition from the python package torch.
#' This implementation dramatically speeds up computation compared to `svd()`
#' in R when calculating the full SVD. This parameter only works when dims==NULL
#' or dims==rank(mat), where calculating a full SVD is demanded.
#' @param princ_coords Integer. Number indicating whether principal
#' coordinates should be calculated for the rows (=1), columns (=2),
#' both (=3) or none (=0).
#' @param dims Integer. Number of CA dimensions to retain. If NULL:
#' (0.2 * min(nrow(A), ncol(A)) - 1 ).
#' @param top Integer. Number of most variable rows to retain.
#' Set NULL to keep all.
#' @param inertia Logical. Whether total, row and column inertias should be
#' calculated and returned.
#' @param rm_zeros Logical. Whether rows & cols containing only 0s should be
#' removed. Keeping zero only rows/cols might lead to unexpected results.
#' @inheritParams calc_residuals
#' @param ... Arguments forwarded to methods.
run_cacomp <- function(obj,
                       coords = TRUE,
                       princ_coords = 3,
                       python = FALSE,
                       dims = 100,
                       top = 5000,
                       inertia = TRUE,
                       rm_zeros = TRUE,
                       residuals = "pearson",
                       cutoff = NULL,
                       clip = FALSE,
                       ...) {
    # Steps: validate dimnames -> optionally drop all-zero rows/cols ->
    # optionally keep the `top` most variable rows -> compute residuals ->
    # truncated SVD -> assemble a "cacomp" object -> optionally add coordinates.
    stopifnot(
        "Input matrix does not have any rownames!" =
            !is.null(rownames(obj))
    )
    stopifnot(
        "Input matrix does not have any colnames!" =
            !is.null(colnames(obj))
    )

    if (python == TRUE) {
        warning(
            "The option `python = TRUE` is deprecated. ",
            "Continuing with base::svd for full SVD or irlba for truncated SVD. ",
            "In the future, please choose the number of dimensions <= 0.2 * min(nrow(mat), ncol(mat)). "
        )
    }
    parameters <- list()

    if (rm_zeros == TRUE) {
        # If the caller asked for exactly all rows, keep "all rows" after the
        # zero rows are gone. `top = NULL` also means "keep all"; the is.null
        # guard avoids a zero-length `if` condition.
        update_top <- !is.null(top) && top == nrow(obj)

        # The logical argument `rm_zeros` does not mask the function of the
        # same name: R skips non-function bindings when resolving a call.
        obj <- rm_zeros(obj)
        if (isTRUE(update_top)) top <- nrow(obj)
    }

    # Choose only top # of variable genes
    if (!is.null(top) && top < nrow(obj)) {
        obj <- var_rows(
            mat = obj,
            top = top,
            residuals = residuals,
            clip = clip,
            cutoff = cutoff
        )

        # Residuals are recomputed on the reduced matrix.
        res <- calc_residuals(
            mat = obj,
            residuals = residuals,
            clip = clip,
            cutoff = cutoff
        )
        toptmp <- top
    } else {
        # top is NULL or not smaller than nrow(obj): all rows are kept.
        # Only warn when a non-NULL top does not match the row count.
        if (!is.null(top) && top != nrow(obj)) {
            if (top > nrow(obj)) {
                warning("\nParameter top is >nrow(obj) and therefore ignored.")
            } else {
                warning("\nUnusual input for top, argument ignored.")
            }
        }

        res <- calc_residuals(
            mat = obj,
            residuals = residuals,
            clip = clip,
            cutoff = cutoff
        )
        toptmp <- nrow(obj)
    }

    S <- res$S
    rowm <- res$rowm
    colm <- res$colm

    # Upper bound used for the number of dimensions.
    k <- min(dim(S)) - 1

    if (is.null(dims)) {
        dims <- floor(0.2 * k)
        if (dims == 0) dims <- 2
        message("No dimensions specified. Setting dimensions to: ", dims)
    }
    if (dims > k) {
        warning(paste0(
            "Number of dimensions is larger than the rank of the matrix. ",
            "Reducing number of dimensions to rank of the matrix."
        ))
        dims <- k
    }

    if (isTRUE(dims >= (0.5 * k))) {
        message(
            "Please consider setting the dimensions to a lower value",
            " to speed up the calculation.",
            "\nRecommended dimensionality: << min(nrows, ncols) * 0.2"
        )
    }

    # S <- (diag(1/sqrt(r)))%*%(P-r%*%t(c))%*%(diag(1/sqrt(c)))
    SVD <- RSpectra::svds(S, k = dims)
    SVD <- SVD[c("u", "d", "v")]
    names(SVD) <- c("U", "D", "V")
    SVD$D <- as.vector(SVD$D)
    if (length(SVD$D) > dims) SVD$D <- SVD$D[seq_len(dims)]

    # Sort by decreasing singular value and label the dimensions.
    ndimv <- ncol(SVD$V)
    ndimu <- ncol(SVD$U)
    ord <- order(SVD$D, decreasing = TRUE)
    SVD$D <- SVD$D[ord]
    SVD$V <- SVD$V[, ord, drop = FALSE]
    SVD$U <- SVD$U[, ord, drop = FALSE]
    names(SVD$D) <- paste0("Dim", seq_len(length(SVD$D)))
    dimnames(SVD$V) <- list(colnames(S), paste0("Dim", seq_len(ndimv)))
    dimnames(SVD$U) <- list(rownames(S), paste0("Dim", seq_len(ndimu)))


    if (inertia == TRUE) {
        # calculate inertia
        SVD$tot_inertia <- sum(SVD$D^2)
        SVD$row_inertia <- Matrix::rowSums(S^2)
        SVD$col_inertia <- Matrix::colSums(S^2)
    }

    SVD$row_masses <- rowm
    SVD$col_masses <- colm


    if (!is.null(dims)) {
        if (dims >= length(SVD$D)) {
            if (dims > length(SVD$D)) {
                warning(
                    "Chosen number of dimensions is larger than the ",
                    "number of dimensions obtained from the singular ",
                    "value decomposition. Argument ignored."
                )
            }
            dims <- length(SVD$D)
        } else {
            dims <- min(dims, length(SVD$D))
            dimseq <- seq(dims)

            # subset to number of dimensions; drop = FALSE keeps U and V as
            # matrices when a single dimension is retained.
            SVD$U <- SVD$U[, dimseq, drop = FALSE]
            SVD$V <- SVD$V[, dimseq, drop = FALSE]
            SVD$D <- SVD$D[dimseq]
        }
    } else {
        dims <- length(SVD$D)
    }


    SVD$dims <- dims
    SVD$top_rows <- toptmp

    # Record the settings used for this run.
    parameters$residuals <- residuals
    parameters$clip <- clip
    parameters$cutoff <- cutoff
    parameters$rm_zeros <- rm_zeros
    parameters$python <- python

    SVD$params <- parameters

    SVD <- do.call(new_cacomp, SVD)

    if (coords == TRUE) {
        SVD <- ca_coords(
            caobj = SVD,
            dims = dims,
            princ_coords = princ_coords,
            princ_only = FALSE
        )
    }

    # check if dimensions with ~zero singular values are selected,
    # in case the dimensions selected are more then rank of matrix
    if (min(SVD@D) <= 1e-7) {
        warning(paste(
            "Too many dimensions are selected!",
            "Number of dimensions should be smaller than rank of matrix!"
        ))
    }


    stopifnot(validObject(SVD))
    return(SVD)
}



#' Correspondence Analysis
#'
#' @description
#' `cacomp` performs correspondence analysis on a matrix or
#' Seurat/SingleCellExperiment object and returns the transformed data.
#'
#' @details
#' The calculation is performed according to the work of Michael Greenacre.
#' Singular value decomposition is performed with `RSpectra::svds`. The former
#' pytorch implementation (`python = TRUE`) is deprecated; setting it only
#' triggers a warning. When working with large matrices, CA coordinates and
#' principal coordinates should only be computed when needed to save
#' computational time.
#'
#' @return
#' Returns a named list of class "cacomp" with components
#' U, V and D: The results from the SVD.
#' row_masses and col_masses: Row and columns masses.
#' top_rows: How many of the most variable rows were retained for the analysis.
#' tot_inertia, row_inertia and col_inertia: Only if inertia = TRUE.
#' Total, row and column inertia respectively.
#' @references
#' Greenacre, M. Correspondence Analysis in Practice, Third Edition, 2017.

#' @param obj A numeric matrix or Seurat/SingleCellExperiment object.
#' For sequencing a count matrix, gene expression values with genes in rows
#' and samples/cells in columns.
#' Should contain row and column names.
#' @inheritParams run_cacomp
#' @inheritParams calc_residuals
#' @param ... Arguments forwarded to methods.
#' @examples
#' # Simulate scRNAseq data.
#' cnts <- data.frame(cell_1 = rpois(10, 5),
#'                    cell_2 = rpois(10, 10),
#'                    cell_3 = rpois(10, 20))
#' rownames(cnts) <- paste0("gene_", 1:10)
#' cnts <- as.matrix(cnts)
#'
#' # Run correspondence analysis.
#' ca <- cacomp(obj = cnts, princ_coords = 3, top = 5)
#' @export
setGeneric("cacomp", function(obj,
                              coords = TRUE,
                              princ_coords = 3,
                              python = FALSE,
                              dims = NULL,
                              top = 5000,
                              inertia = TRUE,
                              rm_zeros = TRUE,
                              residuals = "pearson",
                              cutoff = NULL,
                              clip = FALSE,
                              ...) {
    # Hand off to S4 method dispatch; the methods defined in this file are
    # registered on the class of `obj`.
    standardGeneric("cacomp")
})


#' @rdname cacomp
#' @export
setMethod(
    f = "cacomp",
    # signature() builds the method signature without the stray top-level
    # assignment that `(obj = "matrix")` performed.
    signature = signature(obj = "matrix"),
    function(obj,
             coords = TRUE,
             princ_coords = 3,
             python = FALSE,
             dims = NULL,
             top = 5000,
             inertia = TRUE,
             rm_zeros = TRUE,
             residuals = "pearson",
             cutoff = NULL,
             clip = FALSE,
             ...) {
        # Base matrix input: forward every argument unchanged to run_cacomp
        # and return its result.
        caobj <- run_cacomp(
            obj = obj,
            coords = coords,
            princ_coords = princ_coords,
            python = python,
            dims = dims,
            top = top,
            inertia = inertia,
            rm_zeros = rm_zeros,
            residuals = residuals,
            cutoff = cutoff,
            clip = clip,
            ...
        )

        return(caobj)
    }
)

#' @rdname cacomp
#' @export
setMethod(
    f = "cacomp",
    # signature() builds the method signature without the stray top-level
    # assignment that `(obj = "dgCMatrix")` performed.
    signature = signature(obj = "dgCMatrix"),
    function(obj,
             coords = TRUE,
             princ_coords = 3,
             python = FALSE,
             dims = NULL,
             top = 5000,
             inertia = TRUE,
             rm_zeros = TRUE,
             residuals = "pearson",
             cutoff = NULL,
             clip = FALSE,
             ...) {
        # Sparse (dgCMatrix) input: forward every argument unchanged to
        # run_cacomp and return its result.
        caobj <- run_cacomp(
            obj = obj,
            coords = coords,
            princ_coords = princ_coords,
            python = python,
            dims = dims,
            top = top,
            inertia = inertia,
            rm_zeros = rm_zeros,
            residuals = residuals,
            cutoff = cutoff,
            clip = clip,
            ...
        )

        return(caobj)
    }
)

#' Correspondence Analysis for Seurat objects
#'
#' @description
#' `cacomp.seurat` performs correspondence analysis on an assay from a Seurat
#' container and stores the standardized coordinates of the columns (= cells)
#' and the principal coordinates of the rows (= genes) as a DimReduc Object in
#' the Seurat container.
#'
#' @return
#' If return_input = TRUE with Seurat container: Returns input obj of class
#' "Seurat" with a new Dimensional Reduction Object named "CA".
#' Standard coordinates of the cells are saved as embeddings,
#' the principal coordinates of the genes as loadings and
#' the singular values (= square root of principal inertias/eigenvalues)
#' are stored as stdev.
#' To recompute a regular "cacomp" object without rerunning cacomp use
#' `as.cacomp()`.
#' @param assay Character. The assay from which extract the count matrix for
#' SVD, e.g. "RNA" for Seurat objects or "counts"/"logcounts" for
#' SingleCellExperiments.
#' @param slot character. The slot of the Seurat assay. Default "counts".
#' @param return_input Logical. If TRUE returns the input
#' (SingleCellExperiment/Seurat object) with the CA results saved in the
#' reducedDim/DimReduc slot "CA".
#' Otherwise returns a "cacomp". Default FALSE.
#' @param ... Other parameters
#' @rdname cacomp
#' @export
#' @examples
#'
#' ###########
#' # Seurat  #
#' ###########
#' library(SeuratObject)
#' set.seed(1234)
#'
#' # Simulate counts
#' cnts <- mapply(function(x){rpois(n = 500, lambda = x)},
#'                      x = sample(1:20, 50, replace = TRUE))
#' rownames(cnts) <- paste0("gene_", 1:nrow(cnts))
#' colnames(cnts) <- paste0("cell_", 1:ncol(cnts))
#'
#' # Create Seurat object
#' seu <- CreateSeuratObject(counts = cnts)
#'
#' # Run CA and save in dim. reduction slot
#' seu <- cacomp(seu, return_input = TRUE, assay = "RNA", slot = "counts")
#'
#' # Run CA and return cacomp object
#' ca <- cacomp(seu, return_input = FALSE, assay = "RNA", slot = "counts")
setMethod(
    f = "cacomp",
    # signature() builds the method signature without the stray top-level
    # assignment that `(obj = "Seurat")` performed.
    signature = signature(obj = "Seurat"),
    function(obj,
             coords = TRUE,
             princ_coords = 3,
             python = FALSE,
             dims = NULL,
             top = 5000,
             inertia = TRUE,
             rm_zeros = TRUE,
             residuals = "pearson",
             cutoff = NULL,
             clip = FALSE,
             ...,
             assay = SeuratObject::DefaultAssay(obj),
             slot = "counts",
             return_input = FALSE) {
        # Extracts the requested assay layer, runs run_cacomp on it and either
        # returns the "cacomp" object or stores the result as DimReduc "CA"
        # in the input Seurat object.
        stopifnot("obj doesnt belong to class 'Seurat'" = is(obj, "Seurat"))

        # Coordinates are only needed to fill the DimReduc object, so they
        # are only required when the input object is returned.
        stopifnot(
            "Set coords = TRUE when inputting a Seurat object and return_input = TRUE." =
                return_input == FALSE || coords == TRUE
        )

        seu <- SeuratObject::LayerData(object = obj, assay = assay, layer = slot)

        # Anything that is not already a supported matrix class is densified.
        if (!(is(seu, "matrix") | is(seu, "dgCMatrix") | is(seu, "dgeMatrix"))) {
            seu <- as.matrix(seu)
        }

        caobj <- run_cacomp(
            obj = seu,
            coords = coords,
            top = top,
            princ_coords = princ_coords,
            dims = dims,
            python = python,
            rm_zeros = rm_zeros,
            inertia = inertia,
            residuals = residuals,
            cutoff = cutoff,
            clip = clip,
            ...
        )

        if (return_input == TRUE) {
            # NOTE(review): the renamed U/V columns are not read again in this
            # branch; kept as in the original.
            colnames(caobj@V) <- paste0("DIM_", seq(ncol(caobj@V)))
            colnames(caobj@U) <- paste0("DIM_", seq(ncol(caobj@U)))

            # Cells (columns) -> embeddings, genes (rows) -> loadings,
            # singular values -> stdev, run parameters -> misc.
            obj[["CA"]] <- SeuratObject::CreateDimReducObject(
                embeddings = caobj@std_coords_cols,
                loadings = caobj@prin_coords_rows,
                stdev = caobj@D,
                key = "Dim_",
                assay = assay,
                misc = caobj@params
            )

            return(obj)
        } else {
            return(caobj)
        }
    }
)


#' Correspondence Analysis for SingleCellExperiment objects
#'
#' @description
#' `cacomp.SingleCellExperiment` performs correspondence analysis on an assay
#' from a SingleCellExperiment and stores the standardized coordinates
#'  of the columns (= cells) and the principal coordinates of the rows
#'  (= genes) as a matrix in the SingleCellExperiment container.
#'
#' @return
#' If return_input =TRUE for SingleCellExperiment input returns a
#' SingleCellExperiment object with a matrix of standardized coordinates of
#' the columns in
#' reducedDim(obj, "CA"). Additionally, the matrix contains the following
#' attributes:
#' "prin_coords_rows": Principal coordinates of the rows.
#' "singval": Singular values. For the explained inertia of each principal
#' axis calculate singval^2.
#' "percInertia": Percent explained inertia of each principal axis.
#' To recompute a regular "cacomp" object from a SingleCellExperiment without
#' rerunning cacomp use `as.cacomp()`.
#' @param assay Character. The assay from which extract the count matrix for
#' SVD, e.g. "RNA" for Seurat objects or "counts"/"logcounts" for
#' SingleCellExperiments.
#' @param return_input Logical. If TRUE returns the input
#' (SingleCellExperiment/Seurat object) with the CA results saved in the
#' reducedDim/DimReduc slot "CA".
#'  Otherwise returns a "cacomp". Default FALSE.
#' @rdname cacomp
#' @export
#' @examples
#'
#' ########################
#' # SingleCellExperiment #
#' ########################
#' library(SingleCellExperiment)
#' set.seed(1234)
#'
#' # Simulate counts
#' cnts <- mapply(function(x){rpois(n = 500, lambda = x)},
#'                x = sample(1:20, 50, replace = TRUE))
#' rownames(cnts) <- paste0("gene_", 1:nrow(cnts))
#' colnames(cnts) <- paste0("cell_", 1:ncol(cnts))
#' logcnts <- log2(cnts + 1)
#'
#' # Create SingleCellExperiment object
#' sce <- SingleCellExperiment(assays=list(counts=cnts, logcounts=logcnts))
#'
#' # run CA and save in dim. reduction slot.
#' sce <- cacomp(sce, return_input = TRUE, assay = "counts") # on counts
#' sce <- cacomp(sce, return_input = TRUE, assay = "logcounts") # on logcounts
#'
#' # run CA and return cacomp object.
#' ca <- cacomp(sce, return_input = FALSE, assay = "counts")
setMethod(
    f = "cacomp",
    # signature() builds the method signature without the stray top-level
    # assignment that `(obj = "SingleCellExperiment")` performed.
    signature = signature(obj = "SingleCellExperiment"),
    function(obj,
             coords = TRUE,
             princ_coords = 3,
             python = FALSE,
             dims = NULL,
             top = 5000,
             inertia = TRUE,
             rm_zeros = TRUE,
             residuals = "pearson",
             cutoff = NULL,
             clip = FALSE,
             ...,
             assay = "counts",
             return_input = FALSE) {
        # Extracts the requested assay, runs run_cacomp on it and either
        # returns the "cacomp" object or stores the result in
        # reducedDim(obj, "CA").
        stopifnot("obj doesnt belong to class 'SingleCellExperiment'" = is(obj, "SingleCellExperiment"))

        # Coordinates are only needed to fill the reducedDim slot, so they
        # are only required when the input object is returned.
        stopifnot(
            "Set coords = TRUE when inputting a SingleCellExperiment object and return_input = TRUE." =
                return_input == FALSE || coords == TRUE
        )

        mat <- SummarizedExperiment::assay(obj, assay)

        # Anything that is not already a supported matrix class is densified.
        if (!(is(mat, "matrix") | is(mat, "dgCMatrix") | is(mat, "dgeMatrix"))) {
            mat <- as.matrix(mat)
        }

        # Cap top at the number of rows (a NULL top becomes nrow(mat)).
        top <- min(nrow(mat), top)

        caobj <- run_cacomp(
            obj = mat,
            coords = coords,
            top = top,
            princ_coords = princ_coords,
            dims = dims,
            python = python,
            rm_zeros = rm_zeros,
            inertia = inertia,
            residuals = residuals,
            cutoff = cutoff,
            clip = clip,
            ...
        )

        if (return_input == TRUE) {
            # Percentage of the retained inertia explained by each axis.
            prinInertia <- caobj@D^2
            percentInertia <- prinInertia / sum(prinInertia) * 100

            # Saving the results: column standard coordinates carry the
            # remaining results as attributes.
            ca <- caobj@std_coords_cols
            attr(ca, "prin_coords_rows") <- caobj@prin_coords_rows
            attr(ca, "singval") <- caobj@D
            attr(ca, "percInertia") <- percentInertia
            attr(ca, "params") <- caobj@params

            SingleCellExperiment::reducedDim(obj, "CA") <- ca

            return(obj)
        } else {
            return(caobj)
        }
    }
)


#' Subset dimensions of a caobj
#'
#' @description Subsets the dimensions according to user input.
#'
#' @return Returns caobj.
#'
#' @param caobj A caobj.
#' @param dims Integer. Number of dimensions.
#' @examples
#' # Simulate scRNAseq data.
#' cnts <- data.frame(cell_1 = rpois(10, 5),
#'                    cell_2 = rpois(10, 10),
#'                    cell_3 = rpois(10, 20))
#' rownames(cnts) <- paste0("gene_", 1:10)
#' cnts <- as.matrix(cnts)
#'
#' # Run correspondence analysis.
#' ca <- cacomp(cnts)
#' ca <- subset_dims(ca, 2)
#' @export
subset_dims <- function(caobj, dims) {
    stopifnot(is(caobj, "cacomp"))

    # Nothing requested: hand the object back untouched.
    if (is.null(dims)) {
        return(caobj)
    }

    n_avail <- length(caobj@D)

    if (dims > n_avail) {
        warning(
            "dims is larger than the number of available dimensions.",
            " Argument ignored"
        )
    } else if (dims == n_avail) {
        # Already at the requested size; only record it.
        caobj@dims <- dims
        return(caobj)
    }

    n_keep <- min(dims, n_avail)
    caobj@dims <- n_keep
    keep <- seq(n_keep)

    # SVD results.
    caobj@U <- caobj@U[, keep, drop = FALSE]
    caobj@V <- caobj@V[, keep, drop = FALSE]
    caobj@D <- caobj@D[keep]

    # Coordinate matrices are trimmed only if they have been computed.
    coord_slots <- c(
        "std_coords_cols",
        "prin_coords_cols",
        "std_coords_rows",
        "prin_coords_rows"
    )
    for (slot_name in coord_slots) {
        coord_mat <- slot(caobj, slot_name)
        if (!is.empty(coord_mat)) {
            slot(caobj, slot_name) <- coord_mat[, keep, drop = FALSE]
        }
    }

    stopifnot(validObject(caobj))
    caobj
}


#' Calculate correspondence analysis row and column coordinates.
#'
#' @description `ca_coords` calculates the standardized and principal
#' coordinates of the rows and columns in CA space.
#'
#' @details
#' Takes a "cacomp" object and calculates standardized and principal
#' coordinates for the visualization of CA results in a biplot or
#' to subsequently calculate coordinates in an Association Plot.
#'
#' @return
#' Returns input object with coordinates added.
#' std_coords_rows/std_coords_cols: Standardized coordinates of rows/columns.
#' prin_coords_rows/prin_coords_cols: Principal coordinates of rows/columns.
#'
#' @param caobj A "cacomp" object as outputted from `cacomp()`.
#' @param dims Integer indicating the number of dimensions to use for the
#' calculation of coordinates.
#' All elements of caobj (where applicable) will be reduced to the given
#' number of dimensions. Default NULL (keeps all dimensions).
#' @param princ_only Logical, whether only principal coordinates should be
#' calculated.
#' Or, in other words, whether the standardized coordinates are already
#' calculated and stored in `caobj`. Default `FALSE`.
#' @param princ_coords Integer. Number indicating whether principal
#' coordinates should be calculated for the rows (=1), columns (=2), both (=3)
#' or none (=0).
#' Default 3.
#' @examples
#' # Simulate scRNAseq data.
#' cnts <- data.frame(cell_1 = rpois(10, 5),
#'                    cell_2 = rpois(10, 10),
#'                    cell_3 = rpois(10, 20))
#' rownames(cnts) <- paste0("gene_", 1:10)
#' cnts <- as.matrix(cnts)
#'
#' # Run correspondence analysis.
#' ca <- cacomp(obj = cnts, princ_coords = 1)
#' ca <- ca_coords(ca, princ_coords = 3)
#' @export
ca_coords <- function(caobj, dims = NULL, princ_coords = 3, princ_only = FALSE) {
    stopifnot(is(caobj, "cacomp"))
    stopifnot(
        "princ_coords must be either 0, 1, 2 or 3" =
            length(princ_coords) == 1 && princ_coords %in% c(0, 1, 2, 3)
    )

    if (!is.null(dims)) {
        # Asking for more dimensions than the decomposition holds is an error.
        stopifnot(dims <= length(caobj@D))
        caobj <- subset_dims(caobj = caobj, dims = dims)
    }

    # TRUE only when exactly one dimension was explicitly requested.
    # is.null() is tested first so `dims == 1` is never evaluated on NULL.
    single_dim <- !is.null(dims) && dims == 1

    # Divide every row of the singular vectors by the square root of its mass.
    divide_by_sqrt_mass <- function(vectors, masses) {
        if (single_dim) {
            vectors / sqrt(masses)
        } else {
            sweep(x = vectors, MARGIN = 1, STATS = sqrt(masses), FUN = "/")
        }
    }

    # Multiply every column (dimension) by its singular value.
    scale_by_singular_values <- function(std_coords) {
        if (single_dim) {
            std_coords * caobj@D
        } else {
            sweep(std_coords, 2, caobj@D, "*")
        }
    }

    # Replace NA/NaN/Inf (e.g. from dividing by a zero mass) with 0.
    zero_non_finite <- function(coords) {
        coords[!is.finite(coords)] <- 0
        coords
    }

    if (!princ_only) {
        # standard coordinates
        caobj@std_coords_rows <- zero_non_finite(
            divide_by_sqrt_mass(caobj@U, caobj@row_masses)
        )
        caobj@std_coords_cols <- zero_non_finite(
            divide_by_sqrt_mass(caobj@V, caobj@col_masses)
        )
    }

    if (princ_coords != 0) {
        # Principal coordinates are derived from the standardized ones, which
        # therefore have to be present (computed above or stored beforehand).
        stopifnot(!is.empty(caobj@std_coords_rows))
        stopifnot(!is.empty(caobj@std_coords_cols))

        # 1 = rows only, 2 = columns only, 3 = both.
        if (princ_coords %in% c(1, 3)) {
            caobj@prin_coords_rows <-
                scale_by_singular_values(caobj@std_coords_rows)
        }
        if (princ_coords %in% c(2, 3)) {
            caobj@prin_coords_cols <-
                scale_by_singular_values(caobj@std_coords_cols)
        }
    }

    stopifnot(validObject(caobj))
    return(caobj)
}


#' Scree Plot
#'
#'@description Plots a scree plot.
#'
#'@return
#'Returns a ggplot object.
#'
#'@param df A data frame with columns "dims" and "inertia".
scree_plot <- function(df) {
    # Both columns are mapped to plot aesthetics below.
    stopifnot(c("dims", "inertia") %in% colnames(df))

    # Bars and a connecting line both show the explained inertia per
    # dimension.
    inertia_aes <- ggplot2::aes(
        x = .data$dims,
        y = .data$inertia
    )

    screeplot <- ggplot2::ggplot(df, inertia_aes) +
        ggplot2::geom_col(fill = "#4169E1") +
        ggplot2::geom_line(color = "#B22222", size = 1) +
        ggplot2::labs(
            title = "Scree plot of explained inertia per dimensions and the average inertia",
            y = "Explained inertia [%]",
            x = "Dimension"
        ) +
        ggplot2::theme_bw()

    return(screeplot)
}

#' Runs elbow method
#'
#' @description Helper function for pick_dims() to run the elbow method.
#'
#' @param obj A "cacomp" object as outputted from `cacomp()`
#' @param mat A numeric matrix. For sequencing a count matrix, gene expression
#' values with genes in rows and samples/cells in columns.
#' Should contain row and column names.
#' @param reps Integer. Number of permutations to perform when choosing
#' "elbow_rule".
#' @param return_plot TRUE/FALSE. Whether a plot should be returned when
#' choosing "elbow_rule".
#' @param python A logical value indicating whether to use singular value
#' decomposition from the python package torch.
#' This implementation dramatically speeds up computation compared to `svd()`
#' in R.
#' @return
#' `elbow_method` (for `return_plot=TRUE`) returns a list with two elements:
#' "dims" contains the number of dimensions and "plot" a ggplot. If
#' `return_plot=FALSE` it just returns the number of picked dimensions.
#' @references
#' Ciampi, Antonio, González Marcos, Ana and Castejón Limas, Manuel. \cr
#' Correspondence analysis and 2-way clustering. (2005), SORT 29(1).
#'
#' @examples
#'
#' # Get example data from Seurat
#' library(SeuratObject)
#' set.seed(2358)
#' cnts <- as.matrix(SeuratObject::LayerData(pbmc_small,
#'                                           assay = "RNA",
#'                                           layer = "data"))
#' # Run correspondence analysis.
#' ca <- cacomp(obj = cnts)
#'
#' # pick dimensions with the elbow rule. Returns list.
#' pd <- pick_dims(obj = ca,
#'                 mat = cnts,
#'                 method = "elbow_rule",
#'                 return_plot = TRUE,
#'                 reps = 10)
#' pd$plot
#' ca_sub <- subset_dims(ca, dims = pd$dims)
#'
elbow_method <- function(obj,
                         mat,
                         reps,
                         python = FALSE,
                         return_plot = FALSE) {
    # At least one permutation is needed to have something to compare to.
    stopifnot(
        "reps must be a single number >= 1" =
            is.numeric(reps) && length(reps) == 1 && !is.na(reps) && reps >= 1
    )
    want_plot <- isTRUE(return_plot)

    # Percentage of inertia explained per dimension of the observed data.
    ev <- obj@D^2
    expl_inertia <- (ev / sum(ev)) * 100
    max_num_dims <- length(obj@D)

    # Apply rm_zeros() when the parameters stored on `obj` ask for it.
    if (isTRUE(obj@params$rm_zeros)) {
        mat <- rm_zeros(mat)
    }

    # One column of explained inertia per permutation run.
    matrix_expl_inertia_perm <- matrix(
        0,
        nrow = max_num_dims,
        ncol = reps,
        dimnames = list(NULL, paste0("perm", seq_len(reps)))
    )

    pb <- txtProgressBar(min = 0, max = reps, style = 3)
    # Close the progress bar even when a permutation run fails.
    on.exit(close(pb), add = TRUE)

    for (k in seq_len(reps)) {
        # Shuffle the values within every column independently.
        mat_perm <- apply(mat, 2, FUN = sample)
        colnames(mat_perm) <- colnames(mat)
        rownames(mat_perm) <- seq_len(nrow(mat_perm))

        # Re-run CA on the permuted data with the settings stored on `obj`.
        obj_perm <- cacomp(
            obj = mat_perm,
            top = obj@top_rows,
            dims = obj@dims,
            coords = FALSE,
            python = python,
            residuals = obj@params$residuals,
            cutoff = obj@params$cutoff,
            clip = obj@params$clip
        )

        ev_perm <- obj_perm@D^2

        # NOTE(review): assumes the permuted run yields `max_num_dims`
        # singular values; otherwise this assignment fails.
        matrix_expl_inertia_perm[, k] <- (ev_perm / sum(ev_perm)) * 100

        setTxtProgressBar(pb, k)
    }

    if (want_plot) {
        df <- cbind(
            data.frame(
                dims = seq_len(max_num_dims),
                inertia = expl_inertia
            ),
            matrix_expl_inertia_perm
        )

        screeplot <- scree_plot(df)

        # Overlay one dashed line per permutation run.
        for (k in seq_len(reps)) {
            colnm <- ggplot2::sym(paste0("perm", k))

            screeplot <- screeplot +
                ggplot2::geom_line(
                    data = df, ggplot2::aes(
                        x = .data$dims,
                        y = !!colnm
                    ),
                    color = "black",
                    alpha = 0.8,
                    linetype = 2
                )
        }
    }

    avg_inertia_perm <- rowMeans(matrix_expl_inertia_perm)

    # 1 where the observed inertia exceeds the permuted average, else 0.
    tmp <- as.integer(expl_inertia > avg_inertia_perm)
    n_above <- sum(tmp)

    if (n_above == 0 || n_above == max_num_dims) {
        # No usable crossing point: keep all dimensions.
        dim_number <- max_num_dims
    } else if (tmp[1] == 0) {
        stop(
            "Average inertia of the permutated data is above ",
            "the explained inertia of the data in the first dimension. ",
            "Please either try more permutations or a different method."
        )
    } else {
        # Number of leading dimensions before the first one that does not
        # exceed the permuted average.
        dim_number <- which(tmp == 0)[1] - 1L
    }

    if (want_plot) {
        return(list("dims" = dim_number, "plot" = screeplot))
    }
    return(dim_number)
}


#' Compute statistics to help choose the number of dimensions
#'
#' @description
#' Allow the user to choose from 4 different methods ("avg_inertia",
#' "maj_inertia", "scree_plot" and "elbow_rule")
#' to estimate the number of dimensions that best represent the data.
#'
#' @details
#' * "avg_inertia" calculates the number of dimensions in which the inertia is
#' above the average inertia.
#' * "maj_inertia" calculates the smallest number of dimensions which
#' cumulatively explain more than 80% of the total inertia.
#' * "scree_plot" plots a scree plot.
#' * "elbow_rule" formalization of the commonly used elbow rule. Permutes the
#' rows for each column and reruns `cacomp()` for a total of `reps` times.
#' The number of relevant dimensions is obtained from the point where the
#' line for the explained inertia of the permuted data intersects with the
#' actual data.
#'
#' @return
#' For `avg_inertia`, `maj_inertia` and `elbow_rule` (when `return_plot=FALSE`)
#' returns an integer, indicating the suggested number of dimensions to use.
#' * `scree_plot` returns a ggplot object.
#' * `elbow_rule` (for `return_plot=TRUE`) returns a list with two elements:
#' "dims" contains the number of dimensions and "plot" a ggplot.
#'
#' @param obj A "cacomp" object as outputted from `cacomp()`,
#' a "Seurat" object with a "CA" DimReduc object stored,
#' or a "SingleCellExperiment" object with a "CA" dim. reduction stored.
#' @param mat A numeric matrix. For sequencing a count matrix, gene expression
#' values with genes in rows and samples/cells in columns.
#' Should contain row and column names.
#' @param method String. Either "scree_plot", "avg_inertia", "maj_inertia" or
#' "elbow_rule" (see Details section). Default "scree_plot".
#' @param reps Integer. Number of permutations to perform when choosing
#' "elbow_rule". Default 3.
#' @param return_plot TRUE/FALSE. Whether a plot should be returned when
#' choosing "elbow_rule". Default FALSE.
#' @param python DEPRECATED. A logical value indicating whether to use singular value
#' decomposition from the python package torch.
#' This implementation dramatically speeds up computation compared to `svd()`
#' in R.
#' @param ... Arguments forwarded to methods.
#' @examples
#' # Simulate counts
#' cnts <- mapply(function(x){rpois(n = 500, lambda = x)},
#'                x = sample(1:20, 50, replace = TRUE))
#' rownames(cnts) <- paste0("gene_", 1:nrow(cnts))
#' colnames(cnts) <- paste0("cell_", 1:ncol(cnts))
#'
#' # Run correspondence analysis.
#' ca <- cacomp(obj = cnts)
#'
#' # pick dimensions with the elbow rule. Returns list.
#'
#' set.seed(2358)
#' pd <- pick_dims(obj = ca,
#'                 mat = cnts,
#'                 method = "elbow_rule",
#'                 return_plot = TRUE,
#'                 reps = 10)
#' pd$plot
#' ca_sub <- subset_dims(ca, dims = pd$dims)
#'
#' # pick dimensions which explain cumulatively >80% of total inertia.
#' # Returns vector.
#' pd <- pick_dims(obj = ca,
#'                 method = "maj_inertia")
#' ca_sub <- subset_dims(ca, dims = pd)
#' @export
#' @md
setGeneric(
    name = "pick_dims",
    # Dispatching stub only: the work is done by the class-specific methods.
    def = function(obj, mat = NULL, method = "scree_plot", reps = 3,
                   python = FALSE, return_plot = FALSE, ...) {
        standardGeneric("pick_dims")
    }
)


#' @rdname pick_dims
#' @export
setMethod(
    f = "pick_dims",
    # Dispatch on the class of `obj`.
    signature = methods::signature(obj = "cacomp"),
    function(obj,
             mat = NULL,
             method = "scree_plot",
             reps = 3,
             python = FALSE,
             return_plot = FALSE,
             ...) {
        if (!is(obj, "cacomp")) {
            stop("Not a CA object. Please run cacomp() first!")
        }

        # Percentage of the total inertia explained by each dimension.
        ev <- obj@D^2
        expl_inertia <- (ev / sum(ev)) * 100
        max_num_dims <- length(obj@D)

        if (method == "avg_inertia") {
            # Method 1: count the dimensions that explain more than one
            # dimension does on average.
            avg_inertia <- 100 / max_num_dims
            dim_num <- sum(expl_inertia > avg_inertia)
            return(dim_num)
        } else if (method == "maj_inertia") {
            # Method 2: first dimension at which the cumulative explained
            # inertia (from dimension 1 onwards) exceeds 80%.
            dim_num <- min(which(cumsum(expl_inertia) > 80))
            return(dim_num)
        } else if (method == "scree_plot") {
            # Method 3: scree plot of the explained inertia; the user picks
            # the threshold visually.
            df <- data.frame(
                dims = seq_len(max_num_dims),
                inertia = expl_inertia
            )

            screeplot <- scree_plot(df)

            return(screeplot)
        } else if (method == "elbow_rule") {
            # Method 4: permutation-based elbow rule; needs the input data.
            if (is.null(mat)) {
                # Carry the explanation in the error condition itself so that
                # handlers and logs see it.
                stop(
                    "When running method=\"elbow_rule\", ",
                    "please provide the original data matrix (paramater mat) ",
                    "which was earlier submitted to cacomp()!"
                )
            }

            pd <- elbow_method(
                obj = obj,
                mat = mat,
                reps = reps,
                python = python,
                return_plot = return_plot
            )
            return(pd)
        } else {
            stop("Please pick a valid method!")
        }
    }
)



#' @param assay Character. The assay from which to extract the count matrix for
#' SVD, e.g. "RNA" for Seurat objects or "counts"/"logcounts" for
#' SingleCellExperiments.
#' @param slot Character. Data slot of the Seurat assay.
#' E.g. "data" or "counts". Default "counts".
#'
#' @rdname pick_dims
#' @export
#' @examples
#'
#' ################################
#' # pick_dims for Seurat objects #
#' ################################
#' library(SeuratObject)
#' set.seed(1234)
#'
#' # Simulate counts
#' cnts <- mapply(function(x){rpois(n = 500, lambda = x)},
#'                x = sample(1:20, 50, replace = TRUE))
#' rownames(cnts) <- paste0("gene_", 1:nrow(cnts))
#' colnames(cnts) <- paste0("cell_", 1:ncol(cnts))
#'
#' # Create Seurat object
#' seu <- CreateSeuratObject(counts = cnts)
#'
#' # run CA and save in dim. reduction slot.
#' seu <- cacomp(seu, return_input = TRUE, assay = "RNA", slot = "counts")
#'
#' # pick dimensions
#' pd <- pick_dims(obj = seu,
#'                 method = "maj_inertia",
#'                 assay = "RNA",
#'                 slot = "counts")
setMethod(
    f = "pick_dims",
    # Dispatch on the class of `obj`.
    signature = methods::signature(obj = "Seurat"),
    function(obj,
             mat = NULL,
             method = "scree_plot",
             reps = 3,
             python = FALSE,
             return_plot = FALSE,
             ...,
             assay = SeuratObject::DefaultAssay(obj),
             slot = "counts") {
        stopifnot("obj doesn't belong to class 'Seurat'" = is(obj, "Seurat"))

        # Without a stored "CA" reduction there is nothing to evaluate, so
        # fail before any data is extracted.
        if (!("CA" %in% SeuratObject::Reductions(obj))) {
            stop(
                "No 'CA' dimension reduction object found. ",
                "Please run cacomp(seurat_obj, top, coords = FALSE, ",
                "return_input=TRUE) first."
            )
        }

        caobj <- as.cacomp(obj, assay = assay)
        stopifnot(is(caobj, "cacomp"))

        # Only the elbow rule needs the data. It is always read from `obj`;
        # a user-supplied `mat` is not used by this method.
        counts <- NULL
        if (method == "elbow_rule") {
            counts <- as.matrix(
                SeuratObject::LayerData(
                    object = obj,
                    assay = assay,
                    layer = slot
                )
            )
        }

        # Delegate to the method for "cacomp" objects.
        pick_dims(
            obj = caobj,
            mat = counts,
            method = method,
            reps = reps,
            return_plot = return_plot,
            python = python
        )
    }
)


#' @param assay Character. The assay from which to extract the count matrix
#' for SVD, e.g. "RNA" for Seurat objects or "counts"/"logcounts" for
#' SingleCellExperiments.
#'
#' @rdname pick_dims
#' @export
#' @examples
#'
#' ##############################################
#' # pick_dims for SingleCellExperiment objects #
#' ##############################################
#' library(SingleCellExperiment)
#' set.seed(1234)
#'
#' # Simulate counts
#' cnts <- mapply(function(x){rpois(n = 500, lambda = x)},
#'                x = sample(1:20, 50, replace = TRUE))
#' rownames(cnts) <- paste0("gene_", 1:nrow(cnts))
#' colnames(cnts) <- paste0("cell_", 1:ncol(cnts))
#'
#' # Create SingleCellExperiment object
#' sce <- SingleCellExperiment(assays=list(counts=cnts))
#'
#' # run CA and save in dim. reduction slot.
#' sce <- cacomp(sce, return_input = TRUE, assay = "counts")
#'
#' # pick dimensions
#' pd <- pick_dims(obj = sce,
#'                 method = "maj_inertia",
#'                 assay = "counts")
setMethod(
    f = "pick_dims",
    # Dispatch on the class of `obj`.
    signature = methods::signature(obj = "SingleCellExperiment"),
    function(obj,
             mat = NULL,
             method = "scree_plot",
             reps = 3,
             python = FALSE,
             return_plot = FALSE,
             ...,
             assay = "counts") {
        stopifnot(
            "obj doesn't belong to class 'SingleCellExperiment'" =
                is(obj, "SingleCellExperiment")
        )

        # Without a stored "CA" reduction there is nothing to evaluate, so
        # fail before any data is extracted.
        if (!("CA" %in% SingleCellExperiment::reducedDimNames(obj))) {
            stop(
                "No 'CA' dim. reduction object found. ",
                "Please run cacomp(sce, top, coords = FALSE, ",
                "return_input=TRUE) first."
            )
        }

        caobj <- as.cacomp(obj, assay = assay)
        stopifnot(is(caobj, "cacomp"))

        # Only the elbow rule needs the data. It is always read from `obj`;
        # a user-supplied `mat` is not used by this method.
        counts <- NULL
        if (method == "elbow_rule") {
            counts <- SummarizedExperiment::assay(obj, assay)
        }

        # Delegate to the method for "cacomp" objects.
        pick_dims(
            obj = caobj,
            mat = counts,
            method = method,
            reps = reps,
            return_plot = return_plot,
            python = python
        )
    }
)
# VingronLab/APL documentation built on Nov. 9, 2024, 5:33 p.m.