R/StructureGGplot.R

Defines functions StructureGGplot

Documented in StructureGGplot

#' Structure plot using ggplot2
#'
#' Make the traditional Structure plot of a GoM model with ggplot2.
#'
#' @param omega Cluster membership probabilities of each sample. Usually a
#'              sample by cluster matrix in the Topic model output.
#'              The cluster weights sum to 1 for each sample. Row names
#'              identify the samples and must be unique.
#' @param annotation data.frame of two columns: sample_id and tissue_label.
#'                  sample_id is a vector of character type that uniquely
#'                  identifies each sample.
#'                  tissue_label is a factor giving the sample phenotype that is
#'                  used to sort and group the samples in the Structure plot; for
#'                  example, tissue of origin in making Structure plot of the GTEx
#'                  samples. A non-factor tissue_label is converted with factor().
#'                  Rows must be in the same order as the rows of omega.
#'                  Default is NULL, for when no phenotype information is used to
#'                  order the samples (any value other than a data.frame with at
#'                  least two columns is treated the same way).
#' @param palette Colors assigned to label the clusters. The first color in the palette
#'                  is assigned to the first cluster (first column of omega; cluster
#'                  labels are usually arbitrarily assigned during the clustering
#'                  process). Note: the number of colors must be the same as or
#'                  greater than the number of clusters; otherwise the function
#'                  stops with an error. The recommended choice of color palette is
#'                  RColorBrewer, for instance RColorBrewer::brewer.pal(8, "Accent") or
#'                  RColorBrewer::brewer.pal(9, "Set1").
#' @param figure_title Title of the plot.
#' @param yaxis_label Axis label for the phenotype used to order the samples,
#'                    for example, tissue type or cell type.
#' @param order_sample Whether to order the samples that have the same tissue
#'                      (phenotype) label. If TRUE, the samples within each label
#'                      are sorted by their membership in a representative cluster
#'                      chosen by sample_order_opts. If FALSE, we keep the order
#'                      in the data.
#' @param sample_order_decreasing If order_sample = TRUE, order the samples in
#'                  descending (TRUE) or ascending (FALSE) order.
#' @param sample_order_opts How the sorting cluster of each label group is chosen,
#'                          among the clusters that are dominant (largest weight)
#'                          in at least one sample of the group: 1 = smallest
#'                          cluster index, 2 = most frequent, 3 = largest cluster
#'                          index, 4 = least frequent. Default equal to 1.
#' @param split_line Control parameters for the line that separates phenotype
#'                  subgroups in the plot: a list with split_lwd (line width)
#'                  and split_col (line colour).
#' @param axis_tick Control parameters for the axis ticks and tick labels: a list
#'                  with axis_ticks_length (cm), axis_ticks_lwd_y, axis_ticks_lwd_x,
#'                  axis_label_size and axis_label_face.
#' @param plot_labels If TRUE, plot the phenotype axis labels.
#' @param legend_title_size Text size of the plot titles (applied to the ggplot2
#'                           `title` theme element).
#' @param legend_key_size The size of the legend key, in cm.
#' @param legend_text_size The size of the legend text.
#'
#' @return A ggplot object containing the Structure plot of the GoM model.
#'
#' @examples
#' data("MouseDeng2014.FitGoM")
#'
#' # extract the omega matrix: membership weights of each cell
#' names(MouseDeng2014.FitGoM$clust_6)
#' omega <- MouseDeng2014.FitGoM$clust_6$omega
#' tissue_label <- rownames(omega)
#'
#' # make annotation matrix
#' annotation <- data.frame(
#'   sample_id = paste0("X", c(1:NROW(omega))),
#'   tissue_label = factor(rownames(omega),
#'                      levels = rev( c("zy", "early2cell",
#'                                      "mid2cell", "late2cell",
#'                                      "4cell", "8cell", "16cell",
#'                                      "earlyblast","midblast",
#'                                      "lateblast") ) ) )
#' head(annotation)
#'
#' # set rownames of omega to be sample ID
#' rownames(omega) <- annotation$sample_id
#'
#' StructureGGplot(omega = omega,
#'                  annotation = annotation,
#'                  palette = RColorBrewer::brewer.pal(8, "Accent"),
#'                  yaxis_label = "development phase",
#'                  order_sample = TRUE,
#'                  axis_tick = list(axis_ticks_length = .1,
#'                                   axis_ticks_lwd_y = .1,
#'                                   axis_ticks_lwd_x = .1,
#'                                   axis_label_size = 7,
#'                                   axis_label_face = "bold"))
#'
#' @import ggplot2
#' @importFrom cowplot ggdraw panel_border plot_grid
#' @import plyr
#' @import grDevices
#' @import reshape2
#' @export

StructureGGplot <- function(omega, annotation = NULL,
                            palette = RColorBrewer::brewer.pal(8, "Accent"),
                            figure_title = "",
                            yaxis_label = "Tissue type",
                            order_sample = TRUE,
                            sample_order_decreasing = TRUE,
                            sample_order_opts = 1,
                            split_line = list(split_lwd = 1,
                                              split_col = "white"),
                            plot_labels = TRUE,
                            axis_tick = list(axis_ticks_length = .1,
                                             axis_ticks_lwd_y = .1,
                                             axis_ticks_lwd_x = .1,
                                             axis_label_size = 3,
                                             axis_label_face = "bold"),
                            legend_title_size = 8,
                            legend_key_size = 0.4,
                            legend_text_size = 5) {

    # ---- input checks --------------------------------------------------------
    # every cluster (column of omega) needs its own colour
    if (NCOL(omega) > length(palette)) {
        stop("Color choices is smaller than the number of clusters!")
    }

    # rownames label the bars, so they must identify samples uniquely
    if (length(unique(rownames(omega))) != NROW(omega)) {
        stop("omega rownames are not unique!")
    }

    # Only a data.frame with at least two columns counts as an annotation;
    # NULL (the default) or anything else means "no phenotype grouping".
    null_annotation <- !(is.data.frame(annotation) && length(annotation) > 1)
    if (null_annotation) {
        # factor() matters: since R 4.0.0 data.frame() keeps strings as
        # character, and the grouping below iterates over factor levels.
        annotation <- data.frame(
            sample_id = paste0("X", seq_len(NROW(omega))),
            tissue_label = factor(rep("NA", NROW(omega))))
    } else {
        # identical() rather than all.equal(): all.equal() returns a character
        # description on mismatch, which `!` cannot negate.
        if (!identical(colnames(annotation), c("sample_id", "tissue_label"))) {
            stop("annotation data.frame column names must be sample_id and tissue_label")
        }
        if (length(unique(annotation$sample_id)) != NROW(omega)) {
            stop("sample_id is not unique")
        }
        if (!is.factor(annotation$tissue_label)) {
            annotation$tissue_label <- factor(annotation$tissue_label)
        }
    }

    # ---- order samples within each phenotype group ---------------------------
    # Index of the cluster used to sort one group, chosen among the clusters
    # that are dominant (largest weight) in at least one sample of the group.
    pick_sort_cluster <- function(group_mat) {
        dominant <- apply(group_mat, 1, function(x) which.max(x)[1])
        counts <- table(dominant)  # names are cluster indices, ascending
        pos <- switch(as.character(sample_order_opts),
                      "1" = 1L,                 # smallest dominant index
                      "2" = which.max(counts),  # most frequent
                      "3" = length(counts),     # largest dominant index
                      "4" = which.min(counts),  # least frequent
                      stop("sample_order_opts must be 1, 2, 3 or 4"))
        as.numeric(names(counts)[pos])
    }

    # Rows of one group, sorted by the chosen cluster when ordering applies.
    order_group <- function(group_mat) {
        if (!isTRUE(as.logical(order_sample)) || NROW(group_mat) < 2) {
            return(group_mat)
        }
        sort_col <- pick_sort_cluster(group_mat)
        if (length(sort_col) != 1 || is.na(sort_col)) {
            return(group_mat)  # no dominant cluster found (e.g. all-NA rows)
        }
        # drop = FALSE keeps a matrix even when omega has a single cluster
        group_mat[order(group_mat[, sort_col],
                        decreasing = sample_order_decreasing), , drop = FALSE]
    }

    df_ord <- do.call(rbind, lapply(levels(annotation$tissue_label),
                                    function(lvl) {
        order_group(omega[which(annotation$tissue_label == lvl), ,
                          drop = FALSE])
    }))
    if (is.null(df_ord)) {
        df_ord <- omega[integer(0), , drop = FALSE]  # no labelled samples
    }

    # ---- long format for ggplot ----------------------------------------------
    # One row per (sample, cluster), cluster varying fastest. Built directly so
    # factor levels follow df_ord: numeric-looking rownames are not re-sorted.
    topic_ids <- colnames(df_ord)
    if (is.null(topic_ids)) {
        topic_ids <- as.character(seq_len(NCOL(df_ord)))
    }
    doc_ids <- rownames(df_ord)
    df_mlt <- data.frame(
        topic = factor(rep(topic_ids, times = NROW(df_ord)),
                       levels = unique(topic_ids)),
        document = factor(rep(doc_ids, each = NCOL(df_ord)),
                          levels = unique(doc_ids)),
        value = as.vector(t(df_ord)))

    # inflate numbers to avoid rounding errors
    value_ifl <- 10000
    # number of ticks for the weight axis, including 0 and 1
    ticks_number <- 6

    # ---- phenotype axis ticks and group separators ---------------------------
    tissue_count <- table(droplevels(annotation$tissue_label))
    tissue_names <- names(tissue_count)
    group_sizes <- as.vector(tissue_count)
    group_ends <- cumsum(group_sizes)
    # label each group at its midpoint, or at its only sample
    tissue_breaks <- group_ends - ifelse(group_sizes > 1, group_sizes / 2, 0)
    show_groups <- !null_annotation && length(tissue_names) > 0

    if (show_groups) {
        x_scale <- ggplot2::scale_x_discrete(
            breaks = levels(df_mlt$document)[round(tissue_breaks)],
            labels = tissue_names)
    } else {
        x_scale <- ggplot2::scale_x_discrete(breaks = NULL)
    }

    # ---- build the plot --------------------------------------------------------
    # The theme is attached to the plot rather than set with theme_set(), so the
    # caller's global ggplot2 theme is left untouched.
    # NOTE: line `size` is deprecated in ggplot2 >= 3.4 in favour of `linewidth`;
    # kept for compatibility with older ggplot2 releases.
    b <- ggplot2::ggplot(df_mlt,
                         ggplot2::aes(x = .data$document,
                                      y = .data$value * value_ifl,
                                      fill = .data$topic)) +
        ggplot2::xlab(yaxis_label) + ggplot2::ylab("") +
        ggplot2::scale_fill_manual(values = palette) +
        ggplot2::theme_bw(base_size = 12) +
        ggplot2::theme(
            panel.grid.minor.x = ggplot2::element_blank(),
            panel.grid.minor.y = ggplot2::element_blank(),
            panel.grid.major.x = ggplot2::element_blank(),
            panel.grid.major.y = ggplot2::element_blank(),
            legend.position = "right",
            legend.key.size = ggplot2::unit(legend_key_size, "cm"),
            legend.text = ggplot2::element_text(size = legend_text_size),
            axis.text = ggplot2::element_text(size = axis_tick$axis_label_size,
                                              face = axis_tick$axis_label_face),
            axis.ticks.y = ggplot2::element_line(size = axis_tick$axis_ticks_lwd_y),
            axis.ticks.x = ggplot2::element_line(size = axis_tick$axis_ticks_lwd_x),
            axis.ticks.length = ggplot2::unit(axis_tick$axis_ticks_length, "cm"),
            title = ggplot2::element_text(size = legend_title_size)) +
        ggplot2::ggtitle(figure_title) +
        ggplot2::scale_y_continuous(
            breaks = seq(0, value_ifl, length.out = ticks_number),
            labels = seq(0, 1, 1 / (ticks_number - 1))) +
        x_scale +
        ggplot2::labs(fill = "Clusters") +
        ggplot2::coord_flip() +
        # width = 1 widens the bars so adjacent samples touch
        ggplot2::geom_bar(stat = "identity", position = "stack", width = 1)

    # sample labels option
    if (!isTRUE(as.logical(plot_labels))) {
        b <- b + ggplot2::theme(axis.text.y = ggplot2::element_blank())
    }

    # remove plot border
    b <- b + cowplot::panel_border(remove = TRUE)

    # demarcation lines between consecutive phenotype groups
    if (show_groups && length(group_ends) > 1) {
        b <- b + ggplot2::geom_vline(
            xintercept = group_ends[-length(group_ends)] + 0.5,
            col = split_line$split_col,
            size = split_line$split_lwd)
    }

    b
}
kkdey/CountClust documentation built on Jan. 17, 2021, 5:32 p.m.