R/VisualizationFunctions.R

Defines functions barPlotFunction mvplnVisualize

Documented in mvplnVisualize

#' Visualize Clustered Results Via MVPLN
#'
#' A function to visualize data and clustering results obtained
#' from a mixtures of matrix variate Poisson-log normal (MVPLN) model.
#' Provided a matrix of probabilities for the observations belonging
#' to each cluster, a barplot of probabilities is produced.
#'
#' @param dataset A dataset of class matrix and type integer such that
#'    rows correspond to observations and columns correspond to variables.
#' @param plots A character string indicating which plots should be produced.
#'    Options are 'bar' or 'all'; currently both produce only the bar plot.
#' @param probabilities A matrix of size N x C, such that rows correspond
#'    to N observations and columns correspond to C clusters. Each row
#'    should sum to 1. Default is NA.
#' @param clusterMembershipVector A numeric vector of length nrow(dataset)
#'    containing the cluster membership of each observation, as generated by
#'    a clustering function such as mvplnMCMCclus(). Default is NA.
#' @param printPlot Logical indicating if plot(s) should be saved in the
#'    current working directory (see getwd()). Default TRUE. Options TRUE or FALSE.
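#' @param fileName Unique character string indicating the name for the plot
#'    being generated; the bar plot is saved as barplot_<fileName>.<format>.
#'    Default is Plot_date, where date is obtained from date(). Note that
#'    date() output contains colons, which are not allowed in Windows file names.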
#' @param format Character string indicating the format of the image to
#'    be produced. Default 'pdf'. Options 'pdf' or 'png'.
#'
#' @return A ggplot2 object containing the bar plot of posterior probabilities
#'    when plots is 'bar' or 'all'; otherwise NULL.
#'
#' @examples
#' \dontrun{
#' # Generating simulated matrix variate count data
#' set.seed(1234)
#' trueG <- 2 # number of total G
#' truer <- 2 # number of total occasions
#' truep <- 3 # number of total responses
#' trueN <- 100 # number of total units
#'
#' # Mu is a r x p matrix
#' trueM1 <- matrix(rep(6, (truer * truep)),
#'                  ncol = truep,
#'                  nrow = truer, byrow = TRUE)
#'
#' trueM2 <- matrix(rep(1, (truer * truep)),
#'                  ncol = truep,
#'                  nrow = truer,
#'                  byrow = TRUE)
#'
#' trueMall <- rbind(trueM1, trueM2)
#'
#' # Phi is a r x r matrix
#' # Loading needed packages for generating data
#' # if (!require(clusterGeneration)) install.packages("clusterGeneration")
#' # library("clusterGeneration")
#'
#' # Covariance matrix containing variances and covariances between r occasions
#' # truePhi1 <- clusterGeneration::genPositiveDefMat("unifcorrmat",
#' #                                                   dim = truer,
#' #                                                   rangeVar = c(1, 1.7))$Sigma
#' truePhi1 <- matrix(c(1.075551, -0.488301, -0.488301, 1.362777), nrow = 2)
#' truePhi1[1, 1] <- 1 # For identifiability issues
#'
#' # truePhi2 <- clusterGeneration::genPositiveDefMat("unifcorrmat",
#' #                                                   dim = truer,
#' #                                                   rangeVar = c(0.7, 0.7))$Sigma
#' truePhi2 <- matrix(c(0.7000000, 0.6585887, 0.6585887, 0.7000000), nrow = 2)
#' truePhi2[1, 1] <- 1 # For identifiability issues
#' truePhiall <- rbind(truePhi1, truePhi2)
#'
#' # Omega is a p x p matrix
#' # Covariance matrix containing variances and covariances between p responses
#' # trueOmega1 <- clusterGeneration::genPositiveDefMat("unifcorrmat", dim = truep,
#' #                                    rangeVar = c(1, 1.7))$Sigma
#' trueOmega1 <- matrix(c(1.0526554, 1.0841910, -0.7976842,
#'                        1.0841910,  1.1518811, -0.8068102,
#'                        -0.7976842, -0.8068102,  1.4090578),
#'                        nrow = 3)
#' # trueOmega2 <- clusterGeneration::genPositiveDefMat("unifcorrmat", dim = truep,
#' #                                    rangeVar = c(0.7, 0.7))$Sigma
#' trueOmega2 <- matrix(c(0.7000000, 0.5513744, 0.4441598,
#'                        0.5513744, 0.7000000, 0.4726577,
#'                        0.4441598, 0.4726577, 0.7000000),
#'                        nrow = 3)
#' trueOmegaAll <- rbind(trueOmega1, trueOmega2)
#'
#' # Generated simulated data
#' sampleData <- mixMVPLN::mvplnDataGenerator(nOccasions = truer,
#'                                            nResponses = truep,
#'                                            nUnits = trueN,
#'                                            mixingProportions = c(0.79, 0.21),
#'                                            matrixMean = trueMall,
#'                                            phi = truePhiall,
#'                                            omega = trueOmegaAll)
#'
#' # Clustering simulated matrix variate count data
#' clusteringResults <- mixMVPLN::mvplnMCMCclus(dataset = sampleData$dataset,
#'                                       membership = sampleData$truemembership,
#'                                       gmin = 1,
#'                                       gmax = 2,
#'                                       nChains = 3,
#'                                       nIterations = 300,
#'                                       initMethod = "kmeans",
#'                                       nInitIterations = 1,
#'                                       normalize = "Yes")
#'
#' # Visualize
#' mvplnClustVisuals <- mixMVPLN::mvplnVisualize(
#'   dataset = sampleData$dataset,
#'   plots = 'bar',
#'   probabilities = clusteringResults$allResults[[2]]$allresults$probaPost,
#'   clusterMembershipVector = clusteringResults$allResults[[2]]$allresults$clusterlabels,
#'   fileName = paste0('Plot_', date()),
#'   printPlot = TRUE,
#'   format = 'png')
#' }
#'
#' @author Anjali Silva, \email{anjali@alumni.uoguelph.ca}
#'
#' @references
#' Aitchison, J. and C. H. Ho (1989). The multivariate Poisson-log normal distribution.
#' \emph{Biometrika} 76.
#'
#' Akaike, H. (1973). Information theory and an extension of the maximum likelihood
#' principle. In \emph{Second International Symposium on Information Theory}, New York, NY,
#' USA, pp. 267–281. Springer Verlag.
#'
#' Arlot, S., Brault, V., Baudry, J., Maugis, C., and Michel, B. (2016).
#' capushe: CAlibrating Penalities Using Slope HEuristics. R package version 1.1.1.
#'
#' Biernacki, C., G. Celeux, and G. Govaert (2000). Assessing a mixture model for
#' clustering with the integrated classification likelihood. \emph{IEEE Transactions
#' on Pattern Analysis and Machine Intelligence} 22.
#'
#' Bozdogan, H. (1994). Mixture-model cluster analysis using model selection criteria
#' and a new informational measure of complexity. In \emph{Proceedings of the First US/Japan
#' Conference on the Frontiers of Statistical Modeling: An Informational Approach:
#' Volume 2 Multivariate Statistical Modeling}, pp. 69–113. Dordrecht: Springer Netherlands.
#'
#' Robinson, M.D., and Oshlack, A. (2010). A scaling normalization method for differential
#' expression analysis of RNA-seq data. \emph{Genome Biology} 11, R25.
#'
#' Schwarz, G. (1978). Estimating the dimension of a model. \emph{The Annals of Statistics}
#' 6.
#'
#' Silva, A. et al. (2019). A multivariate Poisson-log normal mixture model
#' for clustering transcriptome sequencing data. \emph{BMC Bioinformatics} 20.
#' \href{https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-019-2916-0}{Link}
#'
#' Silva, A. et al. (2018). Finite Mixtures of Matrix Variate Poisson-Log Normal Distributions
#' for Three-Way Count Data. \href{https://arxiv.org/abs/1807.08380}{arXiv preprint arXiv:1807.08380}.
#'
#' @export
#' @import graphics
#' @import ggplot2
#' @importFrom grDevices png
#' @importFrom grDevices pdf
#' @importFrom grDevices dev.off
#' @importFrom RColorBrewer brewer.pal.info
#' @importFrom RColorBrewer brewer.pal
#' @importFrom reshape melt
mvplnVisualize <- function(dataset,
                           plots = 'bar',
                           probabilities = NA,
                           clusterMembershipVector = NA,
                           fileName = paste0('Plot_',date()),
                           printPlot = TRUE,
                           format = 'pdf') {

  # ---- Checking user input ----
  # `probabilities` and `clusterMembershipVector` default to the logical NA,
  # which acts as the "not supplied" sentinel throughout this function.
  if (is.logical(probabilities)) {
    cat("\n Probabilities are not provided. Barplot of probabilities will not be produced.")
  } else if (is.matrix(probabilities)) {
    # One row of probabilities per observation. The observation count can only
    # be read off matrix-like datasets (the documented input type); other
    # containers (e.g. a list of per-unit matrices) are not checked here.
    if ((is.matrix(dataset) || is.data.frame(dataset)) &&
        nrow(probabilities) != nrow(dataset)) {
      stop("\n nrow(probabilities) should match nrow(dataset)")
    }

    # The membership vector is optional, so it is only validated when supplied.
    if (!all(is.na(clusterMembershipVector)) &&
        length(clusterMembershipVector) != nrow(probabilities)) {
      stop("\n length(clusterMembershipVector) should match nrow(probabilities)")
    }

    rowTotals <- rowSums(probabilities)
    if (anyNA(rowTotals)) {
      stop("\n probabilities should not contain missing values.")
    }
    # Each row should sum to 1, allowing a +/- 0.01 rounding tolerance.
    if (any(rowTotals >= 1.01) || any(rowTotals <= 0.99)) {
      stop("\n rowSums(probabilities) reveals at least\n          one observation has probability != 1.")
    }
  }

  # Plots are saved to the current working directory.
  pathNow <- getwd()

  # Stays NULL unless a bar plot is requested.
  barPlot <- NULL

  if (any(plots %in% c('all', 'bar'))) {

    if (is.logical(probabilities)) {
      stop("\n probabilities should be provided to make bar plot.")
    }
    if (!is.matrix(probabilities)) {
      stop("probabilities should be a matrix")
    }

    # ---- Setting the colours ----
    # One fill colour is needed per probability column (P1, ..., PC).
    nClusters <- ncol(probabilities)
    coloursBarPlot <- c('#4363d8', '#f58231', '#911eb4', '#46f0f0', '#f032e6',
                        '#bcf60c', '#fabebe', '#008080', '#e6beff', '#9a6324',
                        '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1',
                        '#000075', '#808080')
    if (nClusters > length(coloursBarPlot)) {
      # Fall back to every colour of the qualitative RColorBrewer palettes.
      paletteInfo <- RColorBrewer::brewer.pal.info
      qualColPals <- paletteInfo[paletteInfo$category == 'qual', ]
      coloursBarPlot <- unlist(mapply(RColorBrewer::brewer.pal,
                                      qualColPals$maxcolors,
                                      rownames(qualColPals)))
    }
    if (length(coloursBarPlot) < nClusters) {
      # Recycle so scale_fill_manual() always receives enough values.
      coloursBarPlot <- rep_len(coloursBarPlot, nClusters)
    }

    # ---- Bar plot ----
    # Columns: observation index, MAP cluster, then one column per cluster.
    tableProbabilities <- as.data.frame(cbind(Sample = seq_len(nrow(probabilities)),
                                              Cluster = mclust::map(probabilities),
                                              probabilities))
    names(tableProbabilities) <- c("Sample", "Cluster",
                                   paste0("P", seq_len(nClusters)))

    # Long format: one row per (observation, cluster) probability.
    tableProbabilitiesMelt <- reshape::melt(tableProbabilities,
                                            id.vars = c("Sample", "Cluster"))

    # Build the plot once; the same object is saved and returned.
    barPlot <- barPlotFunction(tableProbabilitiesMelt = tableProbabilitiesMelt,
                               coloursBarPlot = coloursBarPlot,
                               probabilities = probabilities)

    if (isTRUE(printPlot)) {
      # Pass the plot explicitly rather than relying on ggplot2::last_plot().
      # NOTE(review): the default fileName embeds date(), whose colons are not
      # valid in Windows file names.
      ggplot2::ggsave(filename = paste0(pathNow, "/barplot_", fileName, ".", format),
                      plot = barPlot)
    }
  }

  barPlot
}


barPlotFunction <- function(tableProbabilitiesMelt,
                            coloursBarPlot,
                            probabilities) {
  # Build a stacked (proportional) bar plot of posterior probabilities.
  #
  # tableProbabilitiesMelt: long data frame with columns Sample, variable
  #   (cluster column name) and value (probability), as produced by melt().
  # coloursBarPlot: character vector of fill colours, one per cluster.
  # probabilities: the N x C probability matrix; only its row count is used,
  #   to set the x-axis range.
  # Returns a ggplot object.

  # Column names referenced inside aes(); bound locally so R CMD check does
  # not report them as undefined global variables.
  Sample <- value <- variable <- NULL

  if (!is.data.frame(tableProbabilitiesMelt)) {
    stop("tableProbabilitiesMelt should be a data frame")
  }
  if (!is.character(coloursBarPlot)) {
    stop("coloursBarPlot should be character")
  }
  if (!is.matrix(probabilities)) {
    stop("probabilities should be a matrix")
  }

  nObservations <- nrow(probabilities)

  # Plain layout: no grid lines, bold axis labels, small base font.
  plotTheme <- ggplot2::theme(
    text = ggplot2::element_text(size = 10),
    panel.grid.major = ggplot2::element_blank(),
    panel.grid.minor = ggplot2::element_blank(),
    axis.text.x = ggplot2::element_text(face = "bold"),
    axis.text.y = ggplot2::element_text(face = "bold")
  )

  ggplot2::ggplot(data = tableProbabilitiesMelt,
                  mapping = ggplot2::aes(x = Sample, y = value, fill = variable)) +
    ggplot2::geom_bar(stat = "identity", position = "fill") +
    ggplot2::scale_fill_manual(name = "Cluster", values = coloursBarPlot) +
    ggplot2::theme_bw() +
    plotTheme +
    ggplot2::coord_cartesian(xlim = c(1, nObservations), ylim = c(0, 1)) +
    ggplot2::labs(x = "Observation") +
    ggplot2::scale_y_continuous(name = "Posterior probability", limits = 0:1)
}

# [END]
anjalisilva/mixMVPLN documentation built on Sept. 24, 2024, 11:05 p.m.